| |
| import dataclasses |
| from collections.abc import Collection, Mapping |
| from enum import auto, Enum |
| from typing import Optional, TYPE_CHECKING, Union |
|
|
| from torch._library.fake_class_registry import FakeScriptObject |
| from torch._subclasses.fake_tensor import is_fake |
|
|
|
|
| if TYPE_CHECKING: |
| import torch |
| from torch._functorch._aot_autograd.schemas import GraphSignature |
|
|
# Public API of this module.
# NOTE(review): TokenArgument is part of ArgumentSpec but is not listed here —
# presumably intentional (tokens are an internal effect-ordering detail); confirm.
__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "SymFloatArgument",
    "SymBoolArgument",
    "TensorArgument",
]
|
|
|
|
@dataclasses.dataclass
class TensorArgument:
    """Spec for a tensor argument, identified by its graph node name."""

    name: str
|
|
|
|
@dataclasses.dataclass
class TokenArgument:
    """Spec for an effect-token argument (used with InputKind/OutputKind TOKEN)."""

    name: str
|
|
|
|
@dataclasses.dataclass
class SymIntArgument:
    """Spec for a symbolic-int (SymInt) argument, identified by node name."""

    name: str
|
|
|
|
@dataclasses.dataclass
class SymFloatArgument:
    """Spec for a symbolic-float (SymFloat) argument, identified by node name."""

    name: str
|
|
|
|
@dataclasses.dataclass
class SymBoolArgument:
    """Spec for a symbolic-bool (SymBool) argument, identified by node name."""

    name: str
|
|
|
|
@dataclasses.dataclass
class CustomObjArgument:
    """Spec for a torchbind custom object argument.

    Carries the node name, the fully-qualified class name of the script
    object, and optionally the FakeScriptObject seen during tracing.
    """

    name: str
    class_fqn: str
    fake_val: Optional[FakeScriptObject] = None
|
|
|
|
@dataclasses.dataclass
class ConstantArgument:
    """Spec for a primitive constant argument, storing the value inline."""

    name: str
    value: Union[int, float, bool, str, None]
|
|
|
|
# Union of every argument-spec variant that can appear in an
# InputSpec.arg / OutputSpec.arg slot.
ArgumentSpec = Union[
    TensorArgument,
    SymIntArgument,
    SymFloatArgument,
    SymBoolArgument,
    ConstantArgument,
    CustomObjArgument,
    TokenArgument,
]
|
|
|
|
class InputKind(Enum):
    """Classification of a graph input (see ``InputSpec.kind``)."""

    USER_INPUT = auto()
    PARAMETER = auto()
    BUFFER = auto()
    CONSTANT_TENSOR = auto()
    CUSTOM_OBJ = auto()
    TOKEN = auto()
|
|
|
|
@dataclasses.dataclass
class InputSpec:
    """Describes one input of an export graph.

    Attributes:
        kind: What the input represents (user input, lifted parameter, ...).
        arg: The argument spec naming the placeholder node.
        target: FQN of the state this input was lifted from, if any.
        persistent: For BUFFER inputs only, whether the buffer is persistent;
            must be set whenever ``kind`` is ``InputKind.BUFFER``.
    """

    kind: InputKind
    arg: ArgumentSpec
    target: Optional[str]
    persistent: Optional[bool] = None

    def __post_init__(self):
        # Buffers must explicitly record persistence; everything else may
        # leave it as None.
        if self.kind == InputKind.BUFFER:
            assert self.persistent is not None, (
                "Failed to specify persistent flag on BUFFER."
            )
        allowed_arg_types = (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
            ConstantArgument,
            CustomObjArgument,
            TokenArgument,
        )
        assert isinstance(self.arg, allowed_arg_types), f"got {type(self.arg)}"

    def __str__(self):
        parts = [f"{self.arg.name}: {self.kind.name}"]
        if self.target is not None:
            parts.append(f" target='{self.target}'")
        if self.persistent is not None:
            parts.append(f" persistent={self.persistent}")
        return "".join(parts)
|
|
|
|
class OutputKind(Enum):
    """Classification of a graph output (see ``OutputSpec.kind``)."""

    USER_OUTPUT = auto()
    LOSS_OUTPUT = auto()
    BUFFER_MUTATION = auto()
    PARAMETER_MUTATION = auto()
    GRADIENT_TO_PARAMETER = auto()
    GRADIENT_TO_USER_INPUT = auto()
    USER_INPUT_MUTATION = auto()
    TOKEN = auto()
|
|
|
|
@dataclasses.dataclass
class OutputSpec:
    """Describes one output of an export graph.

    Attributes:
        kind: What the output represents (user output, mutation, gradient, ...).
        arg: The argument spec naming the output node.
        target: FQN of the state this output writes back to, if any.
    """

    kind: OutputKind
    arg: ArgumentSpec
    target: Optional[str]

    def __post_init__(self):
        allowed_arg_types = (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
            ConstantArgument,
            TokenArgument,
            CustomObjArgument,
        )
        assert isinstance(self.arg, allowed_arg_types), self.arg

    def __str__(self):
        suffix = f" target='{self.target}'" if self.target is not None else ""
        return f"{self.arg.name}: {self.kind.name}{suffix}"
|
|
|
|
@dataclasses.dataclass
class ExportBackwardSignature:
    """Backward-pass signature of a joint graph.

    Maps gradient output node names to the parameter / user-input FQNs they
    correspond to, and records the name of the loss output node.
    """

    gradients_to_parameters: dict[str, str]
    gradients_to_user_inputs: dict[str, str]
    loss_output: str
|
|
|
|
@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariants guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs. Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer("my_buffer1", torch.tensor(3.0))
                self.register_buffer("my_buffer2", torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)  # In-place addition

                return output


        mod = CustomModule()
        ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

    Resulting Graph is non-functional::

        graph():
            %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
            %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
            %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
            %x1 : [num_users=1] = placeholder[target=x1]
            %x2 : [num_users=1] = placeholder[target=x2]
            %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
            %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
            %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
            %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
            %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
            return (add_1,)

    Resulting ExportGraphSignature of the non-functional Graph would be::

        # inputs
        p_my_parameter: PARAMETER target='my_parameter'
        b_my_buffer1: BUFFER target='my_buffer1' persistent=True
        b_my_buffer2: BUFFER target='my_buffer2' persistent=True
        x1: USER_INPUT
        x2: USER_INPUT

        # outputs
        add_1: USER_OUTPUT

    To get a functional Graph, you can use :func:`run_decompositions`::

        mod = CustomModule()
        ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
        ep = ep.run_decompositions()

    Resulting Graph is functional::

        graph():
            %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
            %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
            %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
            %x1 : [num_users=1] = placeholder[target=x1]
            %x2 : [num_users=1] = placeholder[target=x2]
            %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
            %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
            %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
            %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
            %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
            return (add_2, add_1)

    Resulting ExportGraphSignature of the functional Graph would be::

        # inputs
        p_my_parameter: PARAMETER target='my_parameter'
        b_my_buffer1: BUFFER target='my_buffer1' persistent=True
        b_my_buffer2: BUFFER target='my_buffer2' persistent=True
        x1: USER_INPUT
        x2: USER_INPUT

        # outputs
        add_2: BUFFER_MUTATION target='my_buffer2'
        add_1: USER_OUTPUT

    """

    # One spec per graph placeholder / output leaf, in graph order.
    input_specs: list[InputSpec]
    output_specs: list[OutputSpec]

    @property
    def parameters(self) -> Collection[str]:
        """FQNs of all parameters lifted as graph inputs, in input order."""
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            if isinstance(s.target, str)
        )

    @property
    def buffers(self) -> Collection[str]:
        """FQNs of all buffers lifted as graph inputs, in input order."""
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if isinstance(s.target, str)
        )

    @property
    def non_persistent_buffers(self) -> Collection[str]:
        """FQNs of lifted buffers whose ``persistent`` flag is ``False``."""
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if s.persistent is False
            if isinstance(s.target, str)
        )

    @property
    def lifted_tensor_constants(self) -> Collection[str]:
        """FQNs of constant tensors lifted as graph inputs."""
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            if isinstance(s.target, str)
        )

    @property
    def lifted_custom_objs(self) -> Collection[str]:
        """FQNs of custom (torchbind) objects lifted as graph inputs."""
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            if isinstance(s.target, str)
        )

    @property
    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
        """User inputs in order: node names for tensor/sym/custom-obj args,
        the raw value for constant args."""
        user_inputs: list[Union[int, float, bool, None, str]] = []
        for s in self.input_specs:
            if s.kind != InputKind.USER_INPUT:
                continue

            if isinstance(
                s.arg,
                (
                    TensorArgument,
                    SymIntArgument,
                    SymFloatArgument,
                    SymBoolArgument,
                    CustomObjArgument,
                ),
            ):
                user_inputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_inputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user inputs")
        return tuple(user_inputs)

    @property
    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
        """User (and loss) outputs in order: node names for tensor/sym/custom-obj
        args, the raw value for constant args."""
        user_outputs: list[Union[int, float, bool, None, str]] = []
        for s in self.output_specs:
            if s.kind not in [
                OutputKind.USER_OUTPUT,
                OutputKind.LOSS_OUTPUT,
            ]:
                continue

            if isinstance(
                s.arg,
                (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
            ):
                user_outputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_outputs.append(s.arg.value)
            elif isinstance(s.arg, CustomObjArgument):
                user_outputs.append(s.arg.name)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user output")
        return tuple(user_outputs)

    @property
    def inputs_to_parameters(self) -> Mapping[str, str]:
        """Read-only map from placeholder node name to parameter FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def inputs_to_buffers(self) -> Mapping[str, str]:
        """Read-only map from placeholder node name to buffer FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def buffers_to_mutate(self) -> Mapping[str, str]:
        """Read-only map from mutation output node name to mutated buffer FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.BUFFER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def parameters_to_mutate(self) -> Mapping[str, str]:
        """Read-only map from mutation output node name to mutated parameter FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.PARAMETER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def user_inputs_to_mutate(self) -> Mapping[str, str]:
        """Read-only map from mutation output node name to mutated user-input name."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.USER_INPUT_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
        """Read-only map from placeholder node name to constant-tensor FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    @property
    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
        """Read-only map from placeholder node name to custom-object FQN."""
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            and isinstance(s.arg, CustomObjArgument)
            and isinstance(s.target, str)
        )

    @property
    def backward_signature(self) -> Optional[ExportBackwardSignature]:
        """Backward signature assembled from gradient/loss output specs,
        or ``None`` if there is no LOSS_OUTPUT (i.e. not a joint graph)."""
        loss_output = None
        gradients_to_parameters: dict[str, str] = {}
        gradients_to_user_inputs: dict[str, str] = {}
        for spec in self.output_specs:
            if spec.kind == OutputKind.LOSS_OUTPUT:
                # At most one loss output is expected.
                assert loss_output is None
                assert isinstance(spec.arg, TensorArgument)
                loss_output = spec.arg.name
            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_parameters[spec.arg.name] = spec.target
            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_user_inputs[spec.arg.name] = spec.target

        if loss_output is None:
            return None

        return ExportBackwardSignature(
            loss_output=loss_output,
            gradients_to_parameters=gradients_to_parameters,
            gradients_to_user_inputs=gradients_to_user_inputs,
        )

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The assertion dep token is used to express the
    # dependency between the assertion op and the assertion dep token.
    @property
    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
        """Always ``None`` in this base implementation; validated by
        ``__post_init__`` when overridden."""
        return None

    @property
    def input_tokens(self) -> Collection[str]:
        """Node names of all TOKEN-kind inputs, in input order."""
        input_tokens = []
        for s in self.input_specs:
            if s.kind == InputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                input_tokens.append(s.arg.name)
        return tuple(input_tokens)

    @property
    def output_tokens(self) -> Collection[str]:
        """Node names of all TOKEN-kind outputs, in output order."""
        output_tokens = []
        for s in self.output_specs:
            if s.kind == OutputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                output_tokens.append(s.arg.name)
        return tuple(output_tokens)

    def __post_init__(self) -> None:
        # If an assertion dep token exists, it must be the single trailing
        # output after all user outputs and buffer mutations.
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert len(assertion_dep_token) == 1
        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
        assert (
            len(self.user_outputs) + len(self.buffers_to_mutate)
            == assertion_dep_token_index
        )

    def replace_all_uses(self, old: str, new: str):
        """
        Replace all uses of the old name with new name in the signature.
        """
        assert isinstance(old, str)
        assert isinstance(new, str)
        arg_types = (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
            CustomObjArgument,
            TokenArgument,
        )
        for o in self.output_specs:
            if isinstance(o.arg, arg_types):
                if o.arg.name == old:
                    o.arg.name = new
        for i in self.input_specs:
            if isinstance(i.arg, arg_types):
                if i.arg.name == old:
                    i.arg.name = new

    def get_replace_hook(self, replace_inputs=False):
        """Return a callback for fx node replacement that keeps this
        signature in sync; updates outputs always, inputs only when
        ``replace_inputs`` is True."""
        def _(old, new, user):
            if user.op == "output":
                self.replace_all_uses(old.name, new)
            if replace_inputs and old.op == "placeholder":
                self.replace_all_uses(old.name, new)

        return _

    def __str__(self):
        """Render the signature in the same format shown in the class docstring."""
        input_specs = "\n".join(str(s) for s in self.input_specs)
        output_specs = "\n".join(str(s) for s in self.output_specs)
        return f"\n# inputs\n{input_specs}\n\n# outputs\n{output_specs}\n"
|
|
|
|
| def _immutable_dict(items): |
| """ |
| Creates a mapping where items cannot be added, deleted, or updated. |
| NOTE: The immutability is shallow (like tuple is an immutable collection). |
| """ |
| from types import MappingProxyType |
|
|
| return MappingProxyType(dict(items)) |
|
|
|
|
def _make_argument_spec(node, token_names) -> ArgumentSpec:
    """Classify *node* into the matching :data:`ArgumentSpec` variant.

    ``node`` is either a raw python primitive (passed through as a nameless
    constant) or an fx node carrying a ``"val"`` metadata entry.
    ``token_names`` lists node names that represent effect tokens.
    """
    from torch import ScriptObject, SymBool, SymFloat, SymInt
    from torch._library.fake_class_registry import FakeScriptObject

    if isinstance(node, (int, bool, float, type(None), str)):
        # Raw constants appear directly (not as fx nodes), so there is no
        # node name to record.
        return ConstantArgument(name="", value=node)

    assert "val" in node.meta, (
        f"{node} is not a constant or a node with a 'val' metadata field"
    )
    meta_val = node.meta["val"]
    node_name = node.name

    # NOTE: the is_fake check must precede the Sym* checks — the branch
    # order mirrors the original dispatch.
    if node_name in token_names:
        return TokenArgument(name=node_name)
    if is_fake(meta_val):
        return TensorArgument(name=node_name)
    if isinstance(meta_val, SymInt):
        return SymIntArgument(name=node_name)
    if isinstance(meta_val, SymFloat):
        return SymFloatArgument(name=node_name)
    if isinstance(meta_val, SymBool):
        return SymBoolArgument(name=node_name)
    if isinstance(meta_val, ScriptObject):
        return CustomObjArgument(
            name=node_name, class_fqn=meta_val._type().qualified_name()
        )
    if isinstance(meta_val, FakeScriptObject):
        return CustomObjArgument(
            name=node_name, class_fqn=meta_val.script_class_name, fake_val=meta_val
        )
    if isinstance(meta_val, (int, bool, str, float, type(None))):
        return ConstantArgument(name=node_name, value=meta_val)
    raise AssertionError(
        f"Encountered an unsupported object of type {type(meta_val)} "
        f"while writing the metadata for exported program"
    )
|
|
|
|
def _convert_to_export_graph_signature(
    graph_signature: "GraphSignature",
    gm: "torch.fx.GraphModule",
    non_persistent_buffers: set[str],
) -> "ExportGraphSignature":
    """Convert an AOTAutograd ``GraphSignature`` into an ``ExportGraphSignature``.

    Classifies every placeholder of ``gm`` into an :class:`InputSpec` and
    every output leaf into an :class:`OutputSpec`, using the name-based
    lookup tables carried by ``graph_signature``.

    Args:
        graph_signature: The AOTAutograd-produced signature to convert.
        gm: The graph module whose placeholders/outputs are described.
        non_persistent_buffers: FQNs of buffers registered with
            ``persistent=False``; used to fill ``InputSpec.persistent``.

    Returns:
        An ``ExportGraphSignature`` with specs in graph order.
    """
    from torch.utils import _pytree as pytree

    is_joint = graph_signature.backward_signature is not None

    # Name-based lookup tables from the AOTAutograd signature.
    user_inputs = set(graph_signature.user_inputs)
    inputs_to_parameters = graph_signature.inputs_to_parameters
    inputs_to_buffers = graph_signature.inputs_to_buffers
    user_outputs = set(graph_signature.user_outputs)
    buffer_mutations = graph_signature.buffers_to_mutate
    parameter_mutations = graph_signature.parameters_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    # BUG FIX: the BackwardSignature field is ``gradients_to_parameters``
    # (plural); the previous ``gradients_to_parameter`` raised AttributeError
    # for every joint (backward-carrying) graph.
    grad_params = (
        graph_signature.backward_signature.gradients_to_parameters
        if is_joint
        else {}
    )
    grad_user_inputs = (
        graph_signature.backward_signature.gradients_to_user_inputs
        if is_joint
        else {}
    )
    loss_output = (
        graph_signature.backward_signature.loss_output
        if is_joint
        else None
    )
    input_tokens = graph_signature.input_tokens
    output_tokens = graph_signature.output_tokens

    inputs = [
        _make_argument_spec(node, input_tokens)
        for node in gm.graph.nodes
        if node.op == "placeholder"
    ]
    # The output node is the last node in the graph; its args hold the
    # (possibly nested) returned values.
    outputs = [
        _make_argument_spec(node, output_tokens)
        for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
    ]

    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
        # Classify a single placeholder spec by its kind.
        if isinstance(inp, TokenArgument):
            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)

        # Non-tensor arguments can only be user inputs.
        if not isinstance(inp, TensorArgument):
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        name = inp.name
        if name in user_inputs:
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        elif name in inputs_to_parameters:
            return InputSpec(
                kind=InputKind.PARAMETER,
                arg=inp,
                target=inputs_to_parameters[name],
            )
        elif name in inputs_to_buffers:
            return InputSpec(
                kind=InputKind.BUFFER,
                arg=inp,
                target=inputs_to_buffers[name],
                persistent=(inputs_to_buffers[name] not in non_persistent_buffers),
            )
        else:
            raise AssertionError(f"Unknown tensor input kind: {name}")

    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
        # Classify a single output spec by its kind and position.
        if isinstance(o, TokenArgument):
            return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)

        if not isinstance(o, TensorArgument):
            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
        name = o.name
        # Mutations (and tokens) occupy the leading output slots; everything
        # after that index is a user output, gradient, or loss.
        if idx < len(buffer_mutations) + len(parameter_mutations) + len(
            user_input_mutations
        ) + len(output_tokens):
            if name in buffer_mutations:
                return OutputSpec(
                    kind=OutputKind.BUFFER_MUTATION,
                    arg=o,
                    target=buffer_mutations[name],
                )
            elif name in parameter_mutations:
                return OutputSpec(
                    kind=OutputKind.PARAMETER_MUTATION,
                    arg=o,
                    target=parameter_mutations[name],
                )
            elif name in user_input_mutations:
                return OutputSpec(
                    kind=OutputKind.USER_INPUT_MUTATION,
                    arg=o,
                    target=user_input_mutations[name],
                )
            else:
                raise AssertionError(f"Unknown tensor mutation kind: {name}")
        else:
            if name in user_outputs:
                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
            elif name in grad_params:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_PARAMETER,
                    arg=o,
                    target=grad_params[name],
                )
            elif name in grad_user_inputs:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
                    arg=o,
                    target=grad_user_inputs[name],
                )
            elif name == loss_output:
                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
            else:
                raise AssertionError(f"Unknown tensor output kind: {name}")

    input_specs = [to_input_spec(inp) for inp in inputs]
    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
    return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
|
|