|
|
from dataclasses import dataclass, field |
|
|
from typing import TYPE_CHECKING, Dict, List, Tuple, Union |
|
|
|
|
|
from lark import Lark, Token, Tree |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from tensorrt_llm.tools.plugin_gen.core import Argument |
|
|
|
|
|
# Grammar for the shape-deduction mini language. The start rule accepts one of
# two rule forms:
#   deduce_shape:        tensors "->" tensors       e.g. `a[m, k], b[k, n] -> c[m, n]`
#   deduce_dim_size_arg: tensors ":" expr "->" name e.g. `a[m, n] : m * 2 -> dim_arg`
# `-> add`/`sub`/`mul`/`div`/`tensor`/`wildcard_tensor` are aliases that become
# the `data` of the resulting tree nodes; ToAst dispatches on them.
parser = Lark(r"""
value: SIGNED_NUMBER
      | name
      | expr
      | "(" expr ")"

expr: value "+" value -> add
    | value "-" value -> sub
    | value "*" value -> mul
    | value "/" value -> div
    | value

shaped_tensor: name "[" value ("," value)* ("," "*")? "]" -> tensor
             | name "[" "*" "]" -> wildcard_tensor

tensors: shaped_tensor ("," shaped_tensor)*

deduce_shape: tensors "->" tensors

deduce_dim_size_arg: tensors ":" expr "->" name

name: CNAME
?start: deduce_shape | deduce_dim_size_arg

%import common.SIGNED_NUMBER
%import common.WS
%import common.CNAME
%ignore WS
""".strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AST:
    # Common base class of the AST node types produced by ToAst; it carries no
    # state or behaviour of its own.
    pass
|
|
|
|
|
|
|
|
@dataclass
class NumberAST(_AST):
    '''Integer literal appearing in a dimension expression.'''
    # The literal's value; ToAst builds it with `int(...)` from a SIGNED_NUMBER token.
    value: int
|
|
|
|
|
|
|
|
@dataclass
class BinaryAST(_AST):
    '''Binary arithmetic node of a dimension expression.'''
    # Operator symbol; ToAst produces "+", "-", "*" or "/".
    op: str
    # Left operand sub-expression.
    left: _AST
    # Right operand sub-expression.
    right: _AST
|
|
|
|
|
|
|
|
@dataclass
class ShapeAST:
    '''Shape of a tensor as an ordered list of dimension expressions.'''
    # One entry per dimension; a wildcard shape is the single entry DimAST("*").
    dims: List[_AST]
|
|
|
|
|
|
|
|
@dataclass
class DimAST(_AST):
    '''Reference to a dimension by name.'''
    # Dimension identifier; the special name "*" marks a wildcard tensor shape.
    name: str
|
|
|
|
|
|
|
|
@dataclass
class ShapedTensorAST(_AST):
    '''A tensor argument together with its declared shape, e.g. `x[m, n]`.'''
    # Name of the tensor argument as written in the rule.
    arg_name: str
    # Shape written between the brackets.
    shape: ShapeAST
|
|
|
|
|
|
|
|
@dataclass
class DeduceShapeRule(_AST):
    '''Parsed form of a `tensors -> tensors` rule.'''
    # Tensors on the left of "->".
    left: List[ShapedTensorAST]
    # Tensors on the right of "->".
    right: List[ShapedTensorAST]
|
|
|
|
|
|
|
|
@dataclass
class DeduceDimSizeArgRule(_AST):
    '''Parsed form of a `tensors : expr -> name` rule.'''
    # Tensors on the left of ":".
    left: List[ShapedTensorAST]
    # Expression between ":" and "->".
    expr: _AST
    # Name on the right of "->".
    right: str
|
|
|
|
|
|
|
|
class ToAst:
    '''
    Convert a Lark parse tree produced by `parser` into the AST dataclasses.

    Calling an instance with the root tree returns a `DeduceShapeRule` or a
    `DeduceDimSizeArgRule`. Node kinds are compared through `str(tree.data)`
    so the conversion does not depend on whether Lark hands rule names over as
    `Token` instances or as plain strings (aliases such as `add` are strings).
    '''

    # Grammar alias of a binary rule -> operator symbol stored in BinaryAST.
    _BINARY_OPS = {"add": "+", "sub": "-", "mul": "*", "div": "/"}

    def __call__(self,
                 tree: Tree) -> Union[DeduceShapeRule, DeduceDimSizeArgRule]:
        '''
        Convert the root of a parse tree.

        Raises:
            NotImplementedError: if the root is neither of the two rule kinds.
        '''
        kind = str(tree.data)
        if kind == "deduce_shape":
            assert len(tree.children) == 2
            lhs, rhs = tree.children
            return self.visit_DeduceShape(lhs, rhs)
        if kind == "deduce_dim_size_arg":
            assert len(tree.children) == 3
            lhs, expr, target = tree.children
            return self.visit_DeduceDimSizeArg(lhs, expr, target)
        raise NotImplementedError(f"Unsupported rule kind: {kind}")

    def visit_DeduceShape(self, left: Tree, right: Tree) -> DeduceShapeRule:
        '''Build a DeduceShapeRule from the two `tensors` sub-trees.'''
        assert left.data == "tensors"
        assert right.data == "tensors"

        return DeduceShapeRule(self.visit_tensors(left),
                               self.visit_tensors(right))

    def visit_DeduceDimSizeArg(self, left: Tree, expr: Tree,
                               right: Tree) -> DeduceDimSizeArgRule:
        '''Build a DeduceDimSizeArgRule from `tensors`, an expression and a name.'''
        lefts = self.visit_tensors(left)
        _expr = self.visit_expr(expr)
        target = self.visit_name(right)
        return DeduceDimSizeArgRule(lefts, _expr, target)

    def visit_tensors(self, tree: Tree) -> List[ShapedTensorAST]:
        '''Convert every child of a `tensors` node into a ShapedTensorAST.'''
        assert tree.data == "tensors", repr(tree)
        return [self.visit_tensor(child) for child in tree.children]

    def visit_tensor(self, tree: Tree) -> ShapedTensorAST:
        '''
        Convert a `tensor` or `wildcard_tensor` node.

        A wildcard tensor (`x[*]`) gets the single-dimension shape
        `[DimAST("*")]`.
        '''
        arg_name = self.visit_name(tree.children[0])
        if tree.data == "tensor":
            # Children after the name are the dimension expressions.
            dims = [self.visit_expr(child) for child in tree.children[1:]]
            return ShapedTensorAST(arg_name, ShapeAST(dims))

        assert tree.data == "wildcard_tensor", repr(tree)
        return ShapedTensorAST(arg_name, ShapeAST([DimAST("*")]))

    def visit_number(self, v: str) -> _AST:
        '''
        Convert the text of a numeric literal into a NumberAST.

        Raises:
            ValueError: if `v` is not an integer literal (SIGNED_NUMBER also
                matches forms such as `1.5`, which `int()` rejects).
        '''
        try:
            return NumberAST(int(v))
        except ValueError as err:
            raise ValueError(
                f"Only integer literals are supported in dimension expressions, got: {v!r}"
            ) from err

    def visit_expr(self, tree: Tree) -> _AST:
        '''
        for expression of dims, like `m * 2 + 1`

        Raises:
            ValueError: on a token or tree node that is not part of an
                expression, or on a non-integer numeric literal.
        '''

        def visit(node: Union[Tree, Token]) -> _AST:
            if isinstance(node, Token):
                if node.type == "SIGNED_NUMBER":
                    return self.visit_number(node.value)
                if node.type == "CNAME":
                    return DimAST(node.value)
                raise ValueError("Unexpected token: %s" % node)

            kind = str(node.data)
            if kind == "name":
                return DimAST(node.children[0].value)
            if kind in ("value", "expr"):
                # Pure wrapper nodes: the only kept child is the real content.
                return visit(node.children[0])

            op = self._BINARY_OPS.get(kind)
            if op is not None:
                assert len(node.children) == 2
                lhs, rhs = node.children
                return BinaryAST(op, visit(lhs), visit(rhs))

            raise ValueError(f"Unexpected tree: {repr(node)}")

        return visit(tree)

    def visit_name(self, tree: Tree) -> str:
        '''Return the identifier text held by a `name` node.'''
        assert str(tree.data) == "name", repr(tree)
        return tree.children[0].value
|
|
|
|
|
|
|
|
@dataclass
class Dim:
    '''Location of a named dimension: which argument it belongs to and at which position.'''
    # Argument whose shape contains the dimension.
    arg: "Argument"
    # Zero-based position of the dimension inside that argument's shape.
    dim_off: int
|
|
|
|
|
|
|
|
@dataclass
class CppCodeTranspiler:
    '''
    Translate shape-deduction rule strings into C++ source snippets.

    Calling an instance with a list of rule strings returns a pair:
    the list of C++ lines produced for `... -> ...` shape rules, and a dict
    mapping the target name of every `... : expr -> name` rule to its C++
    expression.
    '''

    # Tensor argument name -> its `Argument`; the argument's `offset` is used
    # as the index in the generated C++.
    name_to_arg: Dict[str, "Argument"]

    # Dimension name -> location, bound by the left side of the rule that is
    # currently being emitted (reset for every rule).
    name_to_dim: Dict[str, Dim] = field(default_factory=dict, init=False)

    def __call__(self, exprs: List[str]) -> Tuple[List[str], Dict[str, str]]:
        '''
        Parse and emit each rule in `exprs`, in order.

        Rules are handled one at a time so that `self.cur_expr`, which is
        quoted in error messages, always names the rule being emitted.
        '''
        shape_infer_code: List[str] = []
        dim_size_infer_code: Dict[str, str] = {}
        for expr in exprs:
            lines, dim_exprs = self.codegen([self.to_ast(expr)])
            shape_infer_code.extend(lines)
            dim_size_infer_code.update(dim_exprs)
        return shape_infer_code, dim_size_infer_code

    def to_ast(self, expr: str) -> _AST:
        '''Parse one rule string into its AST and remember it as the current rule.'''
        self.cur_expr = expr
        return ToAst()(parser.parse(expr))

    def codegen(self, asts: List[_AST]) -> Tuple[List[str], Dict[str, str]]:
        '''
        Parse an expression group and generate the corresponding C++ code.

        The syntax of an expression is like below:

        - `name[expr, expr, ...] -> name[expr, expr, ...]`
        - `name[expr, expr, ...]:expr -> dim_arg`

        Raises:
            ValueError: if an element of `asts` is neither rule type.
        '''
        shape_infer_code = []
        dim_size_infer_code = {}

        for ast in asts:
            if isinstance(ast, DeduceShapeRule):
                # Shape rules read input dimensions from `inputDims`.
                self.dim_cpp_repr = self._input_dims_repr
                shape_infer_code.extend(self.emit_DeduceShapeRule(ast))
            elif isinstance(ast, DeduceDimSizeArgRule):
                # Dim-size rules read input dimensions from `inputDesc`.
                self.dim_cpp_repr = self._input_desc_repr
                dim_size_infer_code[ast.right] = self.emit_DeduceDimSizeArgRule(
                    ast)
            else:
                raise ValueError("Unexpected ast: %s" % repr(ast))

        return shape_infer_code, dim_size_infer_code

    @staticmethod
    def _input_dims_repr(arg_idx, dim_idx) -> str:
        '''C++ expression for a dimension when emitting a shape rule.'''
        return f"inputDims[{arg_idx}].d[{dim_idx}]"

    @staticmethod
    def _input_desc_repr(arg_idx, dim_idx) -> str:
        '''C++ expression for a dimension when emitting a dim-size rule.'''
        return f"inputDesc[{arg_idx}].dims.d[{dim_idx}]"

    def _syntax_error(self, detail: str) -> AssertionError:
        '''
        Build the exception reported for a malformed rule.

        AssertionError is kept as the exception type for compatibility with the
        previous `assert`-based checks, but it is raised explicitly so that the
        validation still runs under `python -O`.
        '''
        cur_expr = getattr(self, "cur_expr", "<unknown>")
        return AssertionError(
            f"Wrong syntax in '{cur_expr}', for deduce_shape rule, {detail}")

    def _lookup_arg(self, arg_name: str) -> "Argument":
        '''Return the argument called `arg_name`, with a descriptive KeyError if unknown.'''
        try:
            return self.name_to_arg[arg_name]
        except KeyError:
            cur_expr = getattr(self, "cur_expr", "<unknown>")
            raise KeyError(
                f"Unknown tensor argument '{arg_name}' in '{cur_expr}'"
            ) from None

    @staticmethod
    def is_cur_identical_dims(item: ShapedTensorAST):
        '''Return True when `item` is a wildcard tensor, i.e. its shape is exactly `[*]`.'''
        dims = item.shape.dims
        return len(dims) == 1 and isinstance(dims[0],
                                             DimAST) and dims[0].name == "*"

    def collect_dims_from_left(self, lefts: List[ShapedTensorAST]):
        '''
        Rebuild `name_to_dim` from the left-side tensors of a rule.

        Every left-side dimension must be a plain name. When a name occurs
        more than once, the last occurrence wins.

        Returns:
            Whether the first left-side tensor is a wildcard tensor.

        Raises:
            AssertionError: if a left-side dimension is an expression.
            KeyError: if a tensor name is not in `name_to_arg`.
        '''
        self.name_to_dim.clear()

        is_left_identical_dims = self.is_cur_identical_dims(lefts[0])

        for left in lefts:
            argument = self._lookup_arg(left.arg_name)
            for off, dim in enumerate(left.shape.dims):
                if not isinstance(dim, DimAST):
                    raise self._syntax_error(
                        "each named dimension should be a name rather than an expression"
                    )
                self.name_to_dim[dim.name] = Dim(argument, off)
        return is_left_identical_dims

    def emit_DeduceShapeRule(self, rule: DeduceShapeRule) -> List[str]:
        '''
        Emit the C++ lines that set `outputDims` for every right-side tensor.

        Wildcard rules (`X[*] -> Y[*]`) copy the dims of the first left-side
        tensor; otherwise each output dimension is emitted as an expression
        over the left-side dimension names.

        Raises:
            AssertionError: if wildcard and explicit shapes are mixed on the
                right side, or the two sides disagree on being wildcards.
            KeyError: if a tensor or dimension name is unknown.
        '''
        from tensorrt_llm.tools.plugin_gen.core import code

        is_left_identical_dims = self.collect_dims_from_left(rule.left)
        first_left_tensor_arg = self._lookup_arg(rule.left[0].arg_name)

        # The right side must be uniformly wildcard or uniformly explicit,
        # regardless of the order in which the tensors are written.
        wildcard_flags = [self.is_cur_identical_dims(t) for t in rule.right]
        is_right_identical_dims = bool(wildcard_flags) and all(wildcard_flags)
        if any(wildcard_flags) and not is_right_identical_dims:
            raise self._syntax_error(
                "once one right-side tensor is X[*], all of them should be in X[*] format too"
            )

        if is_left_identical_dims != is_right_identical_dims:
            raise self._syntax_error(
                "the left and right side should be both X[*] or not")

        ret = []
        for tensor in rule.right:
            out_arg = self._lookup_arg(tensor.arg_name)
            ret.append(code(f"if (outputIndex == {out_arg.offset}) {{"))

            if is_right_identical_dims:
                ret.append(
                    code(
                        f" outputDims = inputDims[{first_left_tensor_arg.offset}];"
                    ))
            else:
                ret.append(
                    code(f" outputDims.nbDims = {len(tensor.shape.dims)};"))
                for dim_off, dim in enumerate(tensor.shape.dims):
                    ret.append(
                        code(
                            f" outputDims.d[{dim_off}] = {self.emit_expr(dim)};"
                        ))

            ret.append(code("}"))

        return ret

    def emit_DeduceDimSizeArgRule(self, rule: DeduceDimSizeArgRule) -> str:
        '''Emit the C++ expression computing the target of a dim-size rule.'''
        self.collect_dims_from_left(rule.left)
        return self.emit_expr(rule.expr)

    def emit_expr(self, expr: _AST) -> str:
        '''
        Emit the C++ text of a dimension expression.

        Raises:
            ValueError: if `expr` is not a number, dimension or binary node.
        '''
        if isinstance(expr, NumberAST):
            return str(expr.value)
        if isinstance(expr, DimAST):
            return self.emit_dim(expr)
        if isinstance(expr, BinaryAST):
            return self.emit_binary(expr)
        raise ValueError("Unexpected expr: %s" % expr)

    def emit_dim(self, dim: DimAST) -> str:
        '''
        Emit the C++ access for a named dimension bound on the left side.

        Raises:
            KeyError: if the name was not bound by the rule's left side.
        '''
        try:
            bound: Dim = self.name_to_dim[dim.name]
        except KeyError:
            cur_expr = getattr(self, "cur_expr", "<unknown>")
            raise KeyError(
                f"Dimension '{dim.name}' is not defined on the left side of '{cur_expr}'"
            ) from None
        return self.dim_cpp_repr(bound.arg.offset, bound.dim_off)

    def emit_binary(self, binary: BinaryAST) -> str:
        '''Emit a parenthesised binary expression so operator precedence is explicit.'''
        left = self.emit_expr(binary.left)
        right = self.emit_expr(binary.right)
        return f"({left} {binary.op} {right})"
|
|
|