Spaces:
Runtime error
Runtime error
| ################################################################################################# | |
| # | |
| # Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # | |
| # Redistribution and use in source and binary forms, with or without | |
| # modification, are permitted provided that the following conditions are met: | |
| # | |
| # 1. Redistributions of source code must retain the above copyright notice, this | |
| # list of conditions and the following disclaimer. | |
| # | |
| # 2. Redistributions in binary form must reproduce the above copyright notice, | |
| # this list of conditions and the following disclaimer in the documentation | |
| # and/or other materials provided with the distribution. | |
| # | |
| # 3. Neither the name of the copyright holder nor the names of its | |
| # contributors may be used to endorse or promote products derived from | |
| # this software without specific prior written permission. | |
| # | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| # | |
| ################################################################################################# | |
| """ | |
| Definition of CuTe Layouts and functions to manipulate them | |
| """ | |
| from itertools import chain | |
| from typing import Union | |
| from .int_tuple import * | |
| class LayoutBase: | |
| pass | |
| def is_layout(x): | |
| return isinstance(x, LayoutBase) | |
| class Layout(LayoutBase): | |
| def __init__(self, _shape, _stride=None): | |
| self.shape = _shape | |
| if _stride is None: | |
| self.stride = prefix_product(self.shape) | |
| else: | |
| self.stride = _stride | |
| # operator == | |
| def __eq__(self, other): | |
| return self.shape == other.shape and self.stride == other.stride | |
| # operator len(L) (len [rank] like tuples) | |
| def __len__(self): | |
| if is_tuple(self.shape): | |
| return len(self.shape) | |
| else: | |
| return 1 | |
| # operator () (map coord to idx) | |
| def __call__(self, *args): | |
| """ | |
| Map a logical coordinate to a linear index (Coord has no Underscore slice operators) | |
| OR | |
| Slice the layout and return the sublayout (Coord has an Underscore slice op) | |
| Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ | |
| """ | |
| if has_none(args): | |
| if len(args) == 1: | |
| return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) | |
| else: | |
| return Layout(slice_(args, self.shape), slice_(args, self.stride)) | |
| else: | |
| if len(args) == 1: | |
| return crd2idx(args[0], self.shape, self.stride) | |
| else: | |
| return crd2idx(args, self.shape, self.stride) | |
| # operator [] (get-i like tuples) | |
| def __getitem__(self, i): | |
| if is_tuple(self.shape): | |
| return Layout(self.shape[i], self.stride[i]) | |
| else: | |
| assert i == 0 | |
| return Layout(self.shape, self.stride) | |
| # size(layout) Size of the domain | |
| def size(self): | |
| return product(self.shape) | |
| # cosize(layout) Size of the codomain | |
| def cosize(self): | |
| return self(self.size() - 1) + 1 | |
| # print and str | |
| def __str__(self): | |
| return f"{self.shape}:{self.stride}" | |
| # error msgs and representation | |
| def __repr__(self): | |
| return f"Layout({self.shape},{self.stride})" | |
| # Make Layout from a list of layouts (each layout it's own mode in the result) | |
| def make_layout(*layouts): | |
| if len(layouts) == 1 and not is_layout(layouts[0]): | |
| layouts = layouts[0] | |
| shape, stride = zip(*((a.shape,a.stride) for a in layouts)) | |
| return Layout(shape, stride) | |
| # Size of the domain | |
| def size(layout): | |
| if is_layout(layout): | |
| return layout.size() | |
| return product(layout) | |
| # Size of the codomain | |
| def cosize(layout): | |
| return layout.cosize() | |
| # Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function | |
| def coalesce(layout, profile=None): | |
| if is_tuple(profile): | |
| assert len(layout) >= len(profile) | |
| return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))), | |
| (layout[i] for i in range(len(profile),len(layout))))) | |
| result_shape = [1] | |
| result_stride = [0] | |
| for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): | |
| # skip their shape-1s | |
| if shape == 1: | |
| continue | |
| # replace our shape-1 with anything | |
| elif result_shape[-1] == 1: | |
| result_shape[-1] = shape | |
| result_stride[-1] = stride | |
| # merge modes if the shape*stride match | |
| elif result_shape[-1] * result_stride[-1] == stride: | |
| result_shape[-1] = result_shape[-1] * shape | |
| # append a new mode | |
| else: | |
| result_shape.append(shape) | |
| result_stride.append(stride) | |
| if len(result_shape) == 1: | |
| return Layout(result_shape[0], result_stride[0]) | |
| else: | |
| return Layout(tuple(result_shape), tuple(result_stride)) | |
| # Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them | |
| def filter(layout, profile=None): | |
| if is_tuple(profile): | |
| assert len(layout) >= len(profile) | |
| return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))), | |
| (layout[i] for i in range(len(profile),len(layout))))) | |
| result_shape = [] | |
| result_stride = [] | |
| for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): | |
| # skip their shape-1s and stride-0s | |
| if not (shape == 1 or stride == 0): | |
| result_shape.append(shape) | |
| result_stride.append(stride) | |
| if len(result_shape) == 0: | |
| return Layout(1,0) | |
| else: | |
| return coalesce(Layout(tuple(result_shape), tuple(result_stride))) | |
| # Layout composition | |
| # Use tuples-of-layouts to perform this operation by-mode and None as no-op | |
| def composition(layoutA, layoutB): | |
| if layoutB is None: | |
| return layoutA | |
| elif is_int(layoutB): | |
| return composition(layoutA, Layout(layoutB)) | |
| elif is_tuple(layoutB): | |
| assert len(layoutA) >= len(layoutB) | |
| return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), | |
| (layoutA[i] for i in range(len(layoutB),len(layoutA))))) | |
| elif is_tuple(layoutB.shape): | |
| return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) | |
| if layoutB.stride == 0: | |
| return Layout(layoutB.shape, 0) | |
| else: | |
| result_shape = [] | |
| result_stride = [] | |
| rest_shape = layoutB.shape | |
| rest_stride = layoutB.stride | |
| for (s, d) in zip(flatten(layoutA.shape)[:-1], flatten(layoutA.stride)[:-1]): | |
| s1 = shape_div(s, rest_stride) | |
| result_shape.append(min(s1,rest_shape)) | |
| result_stride.append(rest_stride * d) | |
| rest_shape = shape_div(rest_shape, abs(s1)) | |
| rest_stride = shape_div(rest_stride, s) | |
| result_shape.append(rest_shape) | |
| result_stride.append(rest_stride * flatten(layoutA.stride)[-1]) | |
| return coalesce(Layout(tuple(result_shape), tuple(result_stride))) | |
| # Layout complement | |
| def complement(layout, max_idx=1): | |
| if is_int(layout): | |
| return complement(Layout(layout)) | |
| result_shape = [] | |
| result_stride = [] | |
| current_idx = 1 | |
| sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) | |
| for (stride, shape) in sorted_DS: | |
| if stride == 0 or shape == 1: | |
| continue | |
| in_bound = current_idx <= shape * stride | |
| # To support symbolic value which can't be evaluated now | |
| assert (type(in_bound) is not bool) or in_bound | |
| result_shape.append(stride // current_idx) | |
| result_stride.append(current_idx) | |
| current_idx = shape * stride | |
| result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div | |
| result_stride.append(current_idx) | |
| return coalesce(Layout(tuple(result_shape), tuple(result_stride))) | |
| # Layout right inverse | |
| def right_inverse(layout): | |
| if layout is None: | |
| return None | |
| elif is_int(layout): | |
| return Layout(layout) | |
| result_shape = [] | |
| result_stride = [] | |
| current_idx = 1 | |
| flat_shape = flatten(layout.shape) | |
| flat_stride = flatten(layout.stride) | |
| sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) | |
| for (stride,shape,rstride) in sorted_DSA: | |
| if shape == 1: | |
| continue | |
| if current_idx != stride: | |
| break | |
| result_shape.append(shape) | |
| result_stride.append(rstride) | |
| current_idx = shape * stride | |
| return coalesce(Layout(tuple(result_shape), tuple(result_stride))) | |
| # Layout left inverse | |
| def left_inverse(layout): | |
| if layout is None: | |
| return None | |
| elif is_int(layout): | |
| return Layout(layout) | |
| return right_inverse(make_layout(layout, complement(layout))) | |
| # Split a layout by the composition of B and the "rest" | |
| # Use tuples-of-layouts to perform this operation by-mode and None as no-op | |
| def logical_divide(layoutA, layoutB): | |
| if layoutB is None: | |
| return layoutA | |
| elif is_int(layoutB): | |
| return logical_divide(layoutA, Layout(layoutB)) | |
| elif is_tuple(layoutB): | |
| assert len(layoutA) >= len(layoutB) | |
| return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), | |
| (layoutA[i] for i in range(len(layoutB),len(layoutA))))) | |
| return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))) | |
| # Reproduce a layoutA over a layoutB | |
| # Use tuples-of-layouts to perform this operation by-mode and None as no-op | |
| def logical_product(layoutA, layoutB): | |
| if layoutB is None: | |
| return layoutA | |
| elif is_int(layoutB): | |
| return logical_divide(layoutA, Layout(layoutB)) | |
| elif is_tuple(layoutB): | |
| assert len(layoutA) >= len(layoutB) | |
| return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), | |
| (layoutA[i] for i in range(len(layoutB),len(layoutA))))) | |
| return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB)); | |
| # Gather the modes from a hierarchical logical_divide or logical_product | |
| def hier_unzip(splitter, layoutA, layoutB): | |
| if layoutB is None: | |
| return make_layout(Layout(1,0), layoutA) | |
| elif is_tuple(layoutB): | |
| assert len(layoutA) >= len(layoutB) | |
| # A layout with shape ((A,a),(B,b),(C,c)) | |
| split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB))) | |
| # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) | |
| return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))), | |
| make_layout(chain((split[i][1] for i in range( 0,len(layoutB))), | |
| (layoutA[i] for i in range(len(layoutB),len(layoutA)))))) | |
| # splitter must return a rank-2 layout | |
| return splitter(layoutA, layoutB) | |
| # Apply logical divide hierarchically and gather the split modes into two modes | |
| def zipped_divide(layoutA, layoutB): | |
| return hier_unzip(logical_divide, layoutA, layoutB) | |
| # Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode | |
| def tiled_divide(layoutA, layoutB): | |
| result = zipped_divide(layoutA, layoutB) | |
| return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) | |
| # Apply logical product hierarchically and gather the split modes into two modes | |
| def zipped_product(layoutA, layoutB): | |
| return hier_unzip(logical_product, layoutA, layoutB) | |
| # Perform logical product hierarchically and gather tiles (B-layouts) into a new mode | |
| def tiled_product(layoutA, layoutB): | |
| result = zipped_product(layoutA, layoutB) | |
| return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) | |
| def slice_and_offset(crd: tuple, | |
| layout: Layout): | |
| return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), | |
| crd2idx(crd, layout.shape, layout.stride)) | |