Spaces:
Runtime error
Runtime error
| ################################################################################################# | |
| # | |
| # Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # | |
| # Redistribution and use in source and binary forms, with or without | |
| # modification, are permitted provided that the following conditions are met: | |
| # | |
| # 1. Redistributions of source code must retain the above copyright notice, this | |
| # list of conditions and the following disclaimer. | |
| # | |
| # 2. Redistributions in binary form must reproduce the above copyright notice, | |
| # this list of conditions and the following disclaimer in the documentation | |
| # and/or other materials provided with the distribution. | |
| # | |
| # 3. Neither the name of the copyright holder nor the names of its | |
| # contributors may be used to endorse or promote products derived from | |
| # this software without specific prior written permission. | |
| # | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| # | |
| ################################################################################################# | |
| """ | |
| Methods for layout swizzling | |
| """ | |
| from .layout import * | |
def shiftr(a, s):
    """Shift ``a`` right by ``s`` bits; a negative ``s`` shifts left by ``-s``.

    ``s == 0`` returns ``a`` unchanged.
    """
    # Handle both directions here. Delegating the s <= 0 case to shiftl
    # (which delegates back) never terminated for s == 0.
    return a >> s if s >= 0 else a << -s
def shiftl(a, s):
    """Shift ``a`` left by ``s`` bits; a negative ``s`` shifts right by ``-s``.

    ``s == 0`` returns ``a`` unchanged.
    """
    # Handle both directions here. Delegating the s <= 0 case to shiftr
    # (which delegates back) never terminated for s == 0.
    return a << s if s >= 0 else a >> -s
## A generic Swizzle functor
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
#                               ^--^  Base is the number of least-sig bits to keep constant
#                  ^-^       ^-^      Bits is the number of bits in each mask
#                    ^---------^      Shift is the distance between the YYY and ZZZ fields
#                                       (positive: YYY is the higher field and moves right onto ZZZ;
#                                        negative: YYY is the lower field and moves left onto ZZZ)
#
# e.g. Given
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
# the result is
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
#
class Swizzle:
    """XOR swizzle on integer offsets.

    ``Swizzle(bits, base, shift)(offset)`` XORs the ``bits``-wide field YYY of
    ``offset`` into the field ZZZ, which lies ``abs(shift)`` bits away from it.
    All other bits pass through unchanged.
    """

    def __init__(self, bits, base, shift):
        assert bits >= 0
        assert base >= 0
        # The YYY and ZZZ fields must not overlap.
        assert abs(shift) >= bits
        self.bits = bits
        self.base = base
        self.shift = shift
        bit_msk = (1 << bits) - 1
        # Source field (YYY): above ZZZ when shift > 0, below it when shift < 0.
        self.yyy_msk = bit_msk << (base + max(0, shift))
        # Destination field (ZZZ): abs(shift) bits away from YYY.
        self.zzz_msk = bit_msk << (base - min(0, shift))

    # operator () (transform integer)
    def __call__(self, offset):
        """Return ``offset`` with its YYY bits XOR-ed into its ZZZ bits."""
        moved = offset & self.yyy_msk
        # Move YYY onto ZZZ inline; a zero shift is a plain no-op.
        if self.shift >= 0:
            moved >>= self.shift
        else:
            moved <<= -self.shift
        return offset ^ moved

    # Size of the domain
    def size(self):
        """Return 2 ** (bits + base + abs(shift)), the span covered by both fields."""
        return 1 << (self.bits + self.base + abs(self.shift))

    # Size of the codomain
    def cosize(self):
        """Return the codomain size, which equals the domain size."""
        return self.size()

    # operator == (swizzles with identical parameters are the same mapping)
    def __eq__(self, other):
        if not isinstance(other, Swizzle):
            return NotImplemented
        return (self.bits, self.base, self.shift) == (other.bits, other.base, other.shift)

    # Keep Swizzle hashable, consistent with __eq__.
    def __hash__(self):
        return hash((self.bits, self.base, self.shift))

    # print and str
    def __str__(self):
        return f"SW_{self.bits}_{self.base}_{self.shift}"

    # error msgs and representation
    def __repr__(self):
        return f"Swizzle({self.bits},{self.base},{self.shift})"
class ComposedLayout(LayoutBase):
    """Composition ``layoutB o offset o layoutA``.

    A coordinate is mapped through ``layoutA``, a constant ``offset`` is added,
    and the result is passed to ``layoutB``. ``layoutB`` can be any callable
    that maps an index to an index.
    """

    def __init__(self, layoutB, offset, layoutA):
        self.layoutB = layoutB
        self.offset = offset
        self.layoutA = layoutA

    # operator ==
    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on other.layoutB.
        if not isinstance(other, ComposedLayout):
            return NotImplemented
        return (self.layoutB == other.layoutB
                and self.offset == other.offset
                and self.layoutA == other.layoutA)

    # operator len(L) (len [rank] like tuples)
    def __len__(self):
        """Return the rank of the inner layout ``layoutA``."""
        return len(self.layoutA)

    # operator () (map coord to idx)
    def __call__(self, *args):
        """Return ``layoutB(offset + layoutA(*args))``."""
        return self.layoutB(self.offset + self.layoutA(*args))

    # operator [] (get-i like tuples)
    def __getitem__(self, i):
        """Return the i-th mode of ``layoutA``, composed with the same ``layoutB`` and offset."""
        return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])

    # size(layout) Size of the domain
    def size(self):
        """Return the domain size, taken from ``layoutA``."""
        # `size` here resolves to the module-level function, not this method.
        return size(self.layoutA)

    # cosize(layout) Size of the codomain
    def cosize(self):
        """Return the codomain size, taken from ``layoutB``."""
        return cosize(self.layoutB)

    # print and str
    def __str__(self):
        return f"{self.layoutB} o {self.offset} o {self.layoutA}"

    # error msgs and representation
    def __repr__(self):
        return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"