| | from __future__ import annotations |
| |
|
| | import itertools |
| | from typing import TYPE_CHECKING |
| | from typing_extensions import assert_never |
| |
|
| | from torchgen.api import cpp |
| | from torchgen.api.types import ArgName, Binding, CType, NamedCType |
| | from torchgen.model import ( |
| | Argument, |
| | FunctionSchema, |
| | Return, |
| | SelfArgument, |
| | TensorOptionsArguments, |
| | Type, |
| | ) |
| | from torchgen.utils import concatMap |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from collections.abc import Sequence |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def name(func: FunctionSchema) -> str:
    """Return the dispatcher-API name for the schema ``func``.

    This is exactly the name produced by :func:`torchgen.api.cpp.name`;
    no dispatcher-specific renaming is applied here.
    """
    cpp_name = cpp.name(func)
    return cpp_name
| |
|
| |
|
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Translate the schema type ``t`` into its dispatcher-API C++ type.

    This is currently a pure pass-through: every option (``mutable``,
    ``binds``, ``remove_non_owning_ref_types``, ``symint``) is forwarded
    unchanged to :func:`torchgen.api.cpp.argumenttype_type`, and its result
    is returned as-is.
    """
    # No dispatcher-specific special cases exist at the moment; the C++ API
    # translation is used verbatim.
    return cpp.argumenttype_type(
        t,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
        binds=binds,
        mutable=mutable,
    )
| |
|
| |
|
def argument_type(
    a: Argument,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Compute the dispatcher-API C++ type for the schema argument ``a``.

    The ``mutable`` flag passed down is taken from ``a.is_write``. All other
    options are forwarded unchanged to :func:`argumenttype_type` along with
    ``a.type``.
    """
    is_mutable = a.is_write
    arg_type = a.type
    return argumenttype_type(
        arg_type,
        mutable=is_mutable,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
| |
|
| |
|
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
    """Return the C++ return type for the schema returns ``rs``.

    Delegates to :func:`torchgen.api.cpp.returns_type`, passing ``symint``
    through unchanged.
    """
    result_type = cpp.returns_type(rs, symint=symint)
    return result_type
| |
|
| |
|
def jit_arguments(func: FunctionSchema) -> list[Argument]:
    """Flatten the arguments of ``func`` into a plain list of ``Argument``.

    Positional, keyword-only and out arguments are visited in that order.
    Each item is expanded as follows:

    * an ``Argument`` is kept as-is;
    * a ``SelfArgument`` contributes its wrapped ``argument``;
    * a ``TensorOptionsArguments`` group contributes ``dtype``, ``layout``,
      ``device`` and ``pin_memory``, in that order.

    Any other item type is rejected via ``assert_never``.
    """

    def _expand(
        item: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Argument]:
        # Checks are kept in the same order as the type union is handled
        # elsewhere: plain argument, then self, then tensor-options group.
        if isinstance(item, Argument):
            return [item]
        if isinstance(item, SelfArgument):
            return [item.argument]
        if isinstance(item, TensorOptionsArguments):
            return [item.dtype, item.layout, item.device, item.pin_memory]
        assert_never(item)

    sections = (
        func.arguments.positional,
        func.arguments.kwarg_only,
        func.arguments.out,
    )
    all_items = itertools.chain.from_iterable(sections)
    return list(concatMap(_expand, all_items))
| |
|
| |
|
def argument(
    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
    """Build the dispatcher-API ``Binding`` for a single schema argument.

    Both the binding's ``name`` and the name its C++ type binds to are
    ``a.name``. The source ``Argument`` is stored on the binding.
    """
    arg_name = a.name
    named_ctype = argument_type(
        a,
        binds=arg_name,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
    return Binding(nctype=named_ctype, name=arg_name, argument=a)
| |
|
| |
|
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
    """Return dispatcher-API bindings for every flattened argument of ``func``.

    The arguments come from :func:`jit_arguments`, in its order. Each one is
    bound with :func:`argument` using the default
    ``remove_non_owning_ref_types``; only ``symint`` is passed through.
    """
    flat_args = jit_arguments(func)
    return [argument(flat_arg, symint=symint) for flat_arg in flat_args]
| |
|