File size: 11,244 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

from lark import Lark, Token, Tree

if TYPE_CHECKING:
    from tensorrt_llm.tools.plugin_gen.core import Argument

# Grammar of the shape-rule mini language consumed by CppCodeTranspiler.
#
# Two kinds of rules are accepted (see ``?start``):
#   - deduce_shape:        ``X[m, n], Y[n, k] -> Z[m, k]``
#   - deduce_dim_size_arg: ``X[m, n] : m * n -> dim_arg``
# ``X[*]`` (wildcard_tensor) refers to X's shape as a whole instead of naming
# each dimension.
#
# NOTE(review): the optional trailing ``, *`` in ``tensor`` consists only of
# anonymous string tokens, which lark drops from the parse tree by default, so
# ToAst cannot distinguish ``X[m, *]`` from ``X[m]`` -- confirm this is intended.
# NOTE(review): the arithmetic rules encode no operator precedence, so an
# unparenthesized expression such as ``m * 2 + 1`` is ambiguous and the tree
# chosen is up to lark's ambiguity resolution; parenthesize to be explicit.
parser = Lark(r"""
value: SIGNED_NUMBER
      | name
      | expr
      | "(" expr ")"

expr: value "+" value -> add
    | value "-" value -> sub
    | value "*" value -> mul
    | value "/" value -> div
    | value

shaped_tensor: name "[" value ("," value)* ("," "*")? "]" -> tensor
      | name "[" "*" "]" -> wildcard_tensor

tensors: shaped_tensor ("," shaped_tensor)*

deduce_shape: tensors "->" tensors

deduce_dim_size_arg: tensors ":" expr "->" name

name: CNAME
?start: deduce_shape | deduce_dim_size_arg

%import common.SIGNED_NUMBER
%import common.WS
%import common.CNAME
%ignore WS
""".strip())


# A small set of AST node classes used to represent the parsed rule expressions.
# The parse trees produced by lark are not convenient to work with directly.
class _AST:
    pass


@dataclass
class NumberAST(_AST):
    """Integer literal in a dimension expression, e.g. the ``2`` in ``m * 2``."""

    # The literal's integer value.
    value: int


@dataclass
class BinaryAST(_AST):
    """Binary arithmetic node ``left <op> right`` of a dimension expression."""

    # Operator symbol; ToAst produces one of "+", "-", "*", "/".
    op: str
    # Left operand.
    left: _AST
    # Right operand.
    right: _AST


@dataclass
class ShapeAST:
    """Ordered dimension expressions of a tensor; ``X[*]`` is ``[DimAST("*")]``."""

    # One expression per dimension, in order.
    dims: List["_AST"]


@dataclass
class DimAST(_AST):
    """Reference to a named dimension; the name ``"*"`` marks a wildcard shape."""

    # The dimension's identifier as written in the rule.
    name: str


@dataclass
class ShapedTensorAST(_AST):
    """A tensor argument together with its shape, e.g. ``X[m, n]``."""

    # Argument name, looked up in ``CppCodeTranspiler.name_to_arg``.
    arg_name: str
    # The tensor's dimension expressions.
    shape: ShapeAST


@dataclass
class DeduceShapeRule(_AST):
    """Parsed ``inputs -> outputs`` shape-deduction rule."""

    # Tensors on the left of ``->``; they define the named dimensions.
    left: List[ShapedTensorAST]
    # Tensors on the right of ``->``; their shapes are computed.
    right: List[ShapedTensorAST]


@dataclass
class DeduceDimSizeArgRule(_AST):
    """Parsed ``inputs : expr -> dim_arg`` rule.

    The C++ code generated for ``expr`` is keyed by the argument name ``right``.
    """

    # Tensors whose named dimensions ``expr`` may reference.
    left: List[ShapedTensorAST]
    # The dimension-size expression.
    expr: _AST
    # Name of the argument that receives the computed value.
    right: str


class ToAst:
    '''
    Convert a lark parse tree produced by ``parser`` into this module's AST nodes.

    Call an instance on the tree returned by ``parser.parse``; the result is a
    DeduceShapeRule or a DeduceDimSizeArgRule.
    '''

    # Aliased binary-expression rules of the grammar -> operator symbol.
    _BINARY_OPS = {"add": "+", "sub": "-", "mul": "*", "div": "/"}

    def __call__(self,
                 tree: Tree) -> Union[DeduceShapeRule, DeduceDimSizeArgRule]:
        '''
        Dispatch on the top-level rule of ``tree``.

        Raises:
            NotImplementedError: if the tree is neither a ``deduce_shape`` nor
                a ``deduce_dim_size_arg`` rule.
        '''
        if tree.data == "deduce_shape":
            assert len(tree.children) == 2
            return self.visit_DeduceShape(tree.children[0], tree.children[1])
        elif tree.data == "deduce_dim_size_arg":
            assert len(tree.children) == 3
            return self.visit_DeduceDimSizeArg(tree.children[0],
                                               tree.children[1],
                                               tree.children[2])
        raise NotImplementedError()

    def visit_DeduceShape(self, left: Tree, right: Tree) -> DeduceShapeRule:
        '''Build a DeduceShapeRule from the two ``tensors`` subtrees of ``A -> B``.'''
        assert left.data == "tensors"
        assert right.data == "tensors"

        lefts = self.visit_tensors(left)
        rights = self.visit_tensors(right)
        return DeduceShapeRule(lefts, rights)

    def visit_DeduceDimSizeArg(self, left: Tree, expr: Tree,
                               right: Tree) -> DeduceDimSizeArgRule:
        '''Build a DeduceDimSizeArgRule from the parts of ``A : expr -> name``.'''
        lefts = self.visit_tensors(left)
        _expr = self.visit_expr(expr)
        rights = self.visit_name(right)
        return DeduceDimSizeArgRule(lefts, _expr, rights)

    def visit_tensors(self, tree: Tree) -> List[ShapedTensorAST]:
        '''Convert every shaped tensor in a ``tensors`` subtree, in order.'''
        assert tree.data == "tensors", repr(tree)
        return [self.visit_tensor(child) for child in tree.children]

    def visit_tensor(self, tree: Tree) -> ShapedTensorAST:
        '''
        Convert ``X[d0, d1, ...]`` or the wildcard form ``X[*]``.

        The wildcard form is represented as a single ``DimAST("*")`` dimension.
        '''
        if tree.data == "tensor":
            arg_name = self.visit_name(tree.children[0])
            dims = [self.visit_expr(child) for child in tree.children[1:]]
            return ShapedTensorAST(arg_name, ShapeAST(dims))

        assert tree.data == "wildcard_tensor", repr(tree)
        arg_name = self.visit_name(tree.children[0])
        return ShapedTensorAST(arg_name, ShapeAST([DimAST("*")]))

    def visit_number(self, v: str) -> _AST:
        '''Wrap a numeric literal string as a NumberAST.'''
        return NumberAST(int(v))

    def visit_expr(self, tree: Tree) -> _AST:
        '''
        Convert a dimension expression, like `m * 2 + 1`, into AST nodes.

        Raises:
            ValueError: on a token or rule the grammar should not produce here.
        '''
        binary_ops = self._BINARY_OPS

        def visit(node: Union[Tree, Token]) -> _AST:
            if isinstance(node, Token):
                if node.type == "SIGNED_NUMBER":
                    return NumberAST(int(node.value))
                elif node.type == "CNAME":
                    return DimAST(node.value)
                raise ValueError("Unexpected token: %s" % node)

            # Depending on the lark version, ``Tree.data`` may be a plain str
            # or a Token (a str subclass) -- for aliased rules such as ``add``
            # too. Normalize to str so dispatch does not depend on that.
            rule = str(node.data)
            if rule == "name":
                return DimAST(node.children[0].value)
            if rule in ("value", "expr"):
                # Pass-through rules (including the parenthesized form, whose
                # "(" and ")" are filtered out of the tree by lark).
                return visit(node.children[0])
            op = binary_ops.get(rule)
            if op is not None:
                assert len(node.children) == 2
                return BinaryAST(op, visit(node.children[0]),
                                 visit(node.children[1]))
            raise ValueError(f"Unexpected tree: {repr(node)}")

        return visit(tree)

    def visit_name(self, tree: Tree) -> str:
        '''Return the identifier text of a ``name`` subtree.'''
        assert tree.data == "name", repr(tree)
        return tree.children[0].value


@dataclass
class Dim:
    """Location of a named dimension: position ``dim_off`` in ``arg``'s shape."""

    # The argument whose shape contains this dimension.
    arg: "Argument"
    # Zero-based index of the dimension within that argument's shape.
    dim_off: int


@dataclass
class CppCodeTranspiler:
    '''
    Translate a group of shape-rule strings into C++ code snippets.

    Supported rule syntax:

    - `name[expr, expr, ...] -> name[expr, expr, ...]` (deduce_shape): emits
      statements that set ``outputDims`` for each right-hand tensor.
    - `name[expr, expr, ...]:expr -> dim_arg` (deduce_dim_size_arg): emits one
      C++ expression computing the value of ``dim_arg``.
    '''
    # The mapping from an arg_name in the expression to the corresponding Argument.
    name_to_arg: Dict[str, "Argument"]

    # The mapping from a dim_name in the expression to the corresponding Dim in an Argument.
    name_to_dim: Dict[str, Dim] = field(default_factory=dict, init=False)

    # Source text of the rule currently being processed; only used so that
    # assertion/error messages point at the offending rule.
    cur_expr: str = field(default="", init=False, repr=False, compare=False)

    def __call__(self, exprs: List[str]) -> Tuple[List[str], Dict[str, str]]:
        '''
        Parse every rule string in ``exprs`` and generate C++ code for them.

        Returns the same ``(shape_infer_code, dim_size_infer_code)`` pair as
        ``codegen``.
        '''
        asts = [self.to_ast(expr) for expr in exprs]
        # Pass the sources along: all rules are parsed before codegen starts,
        # so without them every error message would name the last rule.
        return self.codegen(asts, exprs)

    def to_ast(self, expr: str) -> _AST:
        '''Parse one rule string into a DeduceShapeRule or DeduceDimSizeArgRule.'''
        self.cur_expr = expr
        tree = parser.parse(expr)
        return ToAst()(tree)

    def codegen(
        self,
        asts: List[_AST],
        exprs: Union[List[str], None] = None,
    ) -> Tuple[List[str], Dict[str, str]]:
        '''
        Generate C++ code for a group of parsed rules.

        Args:
            asts: rules produced by ``to_ast``.
            exprs: optional source strings of ``asts`` (same order and length);
                used only to make error messages name the right rule.

        Returns:
            ``(shape_infer_code, dim_size_infer_code)``: the C++ lines emitted
            for every DeduceShapeRule, in order, and a dict mapping each
            DeduceDimSizeArgRule's target name to its C++ expression.

        Raises:
            ValueError: if an element of ``asts`` is not a supported rule, or
                ``exprs`` is given with a different length than ``asts``.
        '''
        if exprs is not None and len(exprs) != len(asts):
            raise ValueError(
                "exprs and asts must have the same length, got %d and %d" %
                (len(exprs), len(asts)))

        shape_infer_code = []
        dim_size_infer_code = {}

        for idx, ast in enumerate(asts):
            if exprs is not None:
                self.cur_expr = exprs[idx]
            if isinstance(ast, DeduceShapeRule):
                # Shape rules read input dims from the ``inputDims`` array.
                self.dim_cpp_repr = lambda arg_idx, dim_idx: f"inputDims[{arg_idx}].d[{dim_idx}]"
                shape_infer_code.extend(self.emit_DeduceShapeRule(ast))
            elif isinstance(ast, DeduceDimSizeArgRule):
                # Dim-size rules read input dims from ``inputDesc[i].dims``.
                self.dim_cpp_repr = lambda arg_idx, dim_idx: f"inputDesc[{arg_idx}].dims.d[{dim_idx}]"
                dim_size_infer_code[ast.right] = self.emit_DeduceDimSizeArgRule(
                    ast)
            else:
                raise ValueError("Unexpected ast: %s" % repr(ast))

        return shape_infer_code, dim_size_infer_code

    @staticmethod
    def is_cur_identical_dims(item: ShapedTensorAST):
        '''Return True if ``item`` is written in the wildcard form ``X[*]``.'''
        return len(item.shape.dims) == 1 and isinstance(
            item.shape.dims[0], DimAST) and item.shape.dims[0].name == "*"

    def collect_dims_from_left(self, lefts: List[ShapedTensorAST]):
        '''
        Record where each named dimension of the left-hand tensors lives.

        Resets ``name_to_dim`` and maps every dimension name to its argument and
        offset; a name repeated across tensors keeps its last occurrence.

        Returns:
            True if the first left tensor uses the wildcard form ``X[*]``.
        '''
        self.name_to_dim.clear()

        is_left_identical_dims = self.is_cur_identical_dims(lefts[0])
        # process left, and record the named dimensions
        for left in lefts:
            arg_name = left.arg_name
            argument = self.name_to_arg[arg_name]
            for off, dim in enumerate(left.shape.dims):
                assert isinstance(
                    dim, DimAST
                ), f"Wrong syntax in '{self.cur_expr}', for deduce_shape rule, each named dimension should be a name rather than an expression"
                self.name_to_dim[dim.name] = Dim(argument, off)
        return is_left_identical_dims

    def emit_DeduceShapeRule(self, rule: DeduceShapeRule) -> List[str]:
        '''
        Emit the C++ lines that set ``outputDims`` for each right-hand tensor.

        Each output gets an ``if (outputIndex == <offset>) { ... }`` block. When
        both sides use the wildcard form (``X[*] -> Y[*], ...``) the output
        copies the dims of the first left tensor; otherwise ``nbDims`` and
        every ``d[i]`` are assigned from the right-hand dimension expressions.
        '''
        from tensorrt_llm.tools.plugin_gen.core import code

        is_left_identical_dims = self.collect_dims_from_left(rule.left)

        first_left_tensor_arg = self.name_to_arg[rule.left[0].arg_name]

        # TODO: support more wildcard cases, currently only A[*] -> B[*], C[*] is supported
        right_flags = [self.is_cur_identical_dims(item) for item in rule.right]
        is_right_identical_dims = bool(right_flags) and all(right_flags)
        # Mixing X[*] with explicit shapes on the right side is rejected
        # regardless of the order in which they appear.
        assert is_right_identical_dims or not any(right_flags), "Wrong syntax in '%s', for deduce_shape rule, once the right side contains an X[*] tensor, all right side tensors should be in X[*] format too" % self.cur_expr

        assert is_left_identical_dims == is_right_identical_dims, "Wrong syntax in '%s', for deduce_shape rule, the left and right side should be both X[*] or not" % self.cur_expr

        ret = []
        # process right, and generate the code for each dimensions
        for tensor in rule.right:
            out_arg = self.name_to_arg[tensor.arg_name]
            ret.append(code(f"if (outputIndex == {out_arg.offset}) {{"))

            if is_right_identical_dims:
                ret.append(
                    code(
                        f"  outputDims = inputDims[{first_left_tensor_arg.offset}];"
                    ))
            else:
                ret.append(
                    code(f"  outputDims.nbDims = {len(tensor.shape.dims)};"))
                for dim_off, dim in enumerate(tensor.shape.dims):
                    ret.append(
                        code(
                            f"  outputDims.d[{dim_off}] = {self.emit_expr(dim)};"
                        ))

            ret.append(code("}"))

        return ret

    def emit_DeduceDimSizeArgRule(self, rule: DeduceDimSizeArgRule) -> str:
        '''Return the C++ expression computing the rule's dimension size.'''
        self.collect_dims_from_left(rule.left)
        return self.emit_expr(rule.expr)

    def emit_expr(self, expr: _AST) -> str:
        '''
        Render a dimension expression as C++.

        Raises:
            ValueError: if ``expr`` is not a NumberAST, DimAST or BinaryAST.
        '''
        if isinstance(expr, NumberAST):
            return str(expr.value)
        elif isinstance(expr, DimAST):
            return self.emit_dim(expr)
        elif isinstance(expr, BinaryAST):
            return self.emit_binary(expr)
        raise ValueError("Unexpected expr: %s" % expr)

    def emit_dim(self, dim: DimAST) -> str:
        '''
        Render a named dimension as an access into the input dims.

        Raises:
            KeyError: if no left-hand tensor declares the dimension name.
        '''
        try:
            location = self.name_to_dim[dim.name]
        except KeyError:
            raise KeyError(
                f"Wrong syntax in '{self.cur_expr}', dimension '{dim.name}' is "
                "not declared by any tensor on the left side") from None
        return self.dim_cpp_repr(location.arg.offset, location.dim_off)

    def emit_binary(self, binary: BinaryAST) -> str:
        '''Render ``left op right`` fully parenthesized to preserve grouping.'''
        left = self.emit_expr(binary.left)
        right = self.emit_expr(binary.right)
        return f"({left} {binary.op} {right})"