| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
#ifndef TF_OPS
#define TF_OPS

include "tf_generated_ops.td"
include "tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"
|
|
// Base class for TensorFlow tensor-list creation ops. Subclasses declare the
// arguments (e.g. element shape, element count); this class fixes the single
// variant-typed result handle and shares the verification/accessor logic.
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
  let results = (outs
    TF_VariantTensor:$handle
  );

  // Derived from the type of operand 0 (the element-shape operand that every
  // subclass declares first).
  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;

  let verifier = [{
    // The result variant type must carry exactly one subtype: the element
    // type of the list. element_type() below relies on this invariant.
    if (handle_dtype().getSubtypes().size() != 1) {
      return emitOpError(
          "must have exactly one subtype in the result variant type");
    }
    // Delegate remaining checks to the out-of-line Verify() overload.
    return Verify(*this);
  }];

  // Element dtype is recomputed from the result type rather than stored.
  DerivedTypeAttr element_dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(element_type());">;

  let extraClassDeclaration = [{
    // Returns the element type of the list: the variant's single subtype.
    // Only meaningful once the verifier above has run.
    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }

    // Returns the variant element type of the result handle. The cast
    // assumes the result is variant-typed, as declared in `results`.
    VariantType handle_dtype() {
      return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
    }
  }];
}
|
|
// Functional form of an n-way switch: branch bodies live in separate
// functions referenced by symbol, selected at runtime by `branch_index`.
def TF_CaseOp : TF_Op<"Case", []> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
  case 0:
    output = branches[0](input);
    break;
  case 1:
    output = branches[1](input);
    break;
  ...
  case [[nbranches-1]]:
  default:
    output = branches[nbranches-1](input);
    break;
}
```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,
    Variadic<TF_Tensor>:$input,

    // Symbol references to the branch functions; at least one is required.
    Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes, reconstructed from operand/result types on export.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let verifier = [{
    return Verify(*this);
  }];

  let extraClassDeclaration = [{
    // Number of attached branch functions.
    int num_branches() { return branches().size(); }

    // Resolves the branch function at `index` from the nearest symbol table.
    FuncOp branch_function(int index) {
      auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, flat_sym_ref);
    }

    // Resolves all branch functions, in order, into `functions`.
    void get_branch_functions(SmallVectorImpl<FuncOp> &functions) {
      functions.reserve(num_branches());
      for (int idx : llvm::seq<int>(0, num_branches()))
        functions.push_back(branch_function(idx));
    }
  }];
}
|
|
// Region-based form of tf.Case: each branch is an attached single-block
// region terminated by tf.Yield, instead of a referenced function.
def TF_CaseRegionOp : TF_Op<"CaseRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
  case 0:
    output = branches[0](input);
    break;
  case 1:
    output = branches[1](input);
    break;
  ...
  case [[nbranches-1]]:
  default:
    output = branches[nbranches-1](input);
    break;
}
```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // One single-block region per branch; branches capture values implicitly
  // (NoRegionArguments above forbids explicit block arguments).
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let verifier = [{
    return Verify(*this);
  }];
}
|
|
| |
| |
// Constant tensor: materializes the `value` attribute as the op's result.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> {
  let summary = "Constant tensor op";

  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    TF_Tensor:$output
  );

  // dtype is derived from the result type rather than stored.
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

  let builders = [
    // Build from the value attribute alone (result type taken from it).
    OpBuilder<(ins "Attribute":$value)>,
    // Build with an explicit result type.
    OpBuilder<(ins "Type":$type, "Attribute":$value)>,
  ];
}
|
|
// Cross-replica permutation of tensors on TPU.
def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", []> {
  let summary = "An Op to permute tensors across replicated TPU instances.";

  let description = [{
Each instance supplies its own input.

For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
`[D, A, B, C]`.
  }];

  let arguments = (ins
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
    I32Tensor:$source_target_pairs
  );

  let results = (outs
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  // Element type attribute derived from operand 0.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// All-reduce across replicated TPU instances; input and output types must
// match (enforced by the TF_AllTypesMatch trait).
def TF_XlaAllReduceOp : TF_Op<"XlaAllReduce", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
  let summary = "An Op to reduce inputs across replicated TPU instances.";

  let arguments = (ins
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$input,
    TF_Int32Tensor:$group_assignment,
    // Reduction to apply; restricted to the five listed kernels.
    TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add", "Mean"]>:$reduce_op
  );

  let results = (outs
    TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Int32, TF_Uint32]>:$output
  );

  // Element type attribute derived from operand 0.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Creates an empty tensor list. Result handle, verifier and accessors come
// from the TF_TensorListInitOp base class.
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
  let summary = "Creates and returns an empty tensor list.";

  let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.

handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
  }];

  let arguments = (ins
    // Operand 0: feeds the base class's derived shape_type attribute.
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$max_num_elements
  );
}
|
|
| |
| |
| |
| |
| |
// Pass-through op; operand and result types must agree (possibly through a
// ref type), per the TF_OperandsSameAsResultsTypeOrRef trait.
def TF_IdentityOp : TF_Op<"Identity", [TF_OperandsSameAsResultsTypeOrRef]> {
  let summary = "Identity op";

  let description = [{
Returns a tensor with the same shape and contents as input.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  // Element type attribute derived from result 0.
  TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
}
|
|
// Functional conditional: then/else bodies are functions referenced by
// symbol. See TF_IfRegionOp for the region-based form.
def TF_IfOp : TF_Op<"If", []> {
  let summary = "output = cond ? then_branch(input) : else_branch(input)";

  let description = [{
output = cond ? then_branch(input) : else_branch(input)

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A function that takes 'inputs' and returns a list of
    tensors, whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of
    tensors. whose types are the same as what then_branch returns.
  }];

  let arguments = (ins
    TF_Tensor:$cond,
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$then_branch,
    FlatSymbolRefAttr:$else_branch,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes, reconstructed from operand/result types on export.
  TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Resolves the then-branch function from the nearest symbol table.
    FuncOp then_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, then_branch());
    }

    // Resolves the else-branch function from the nearest symbol table.
    FuncOp else_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, else_branch());
    }
  }];
}
|
|
// Terminator for the region-based control-flow ops listed in ParentOneOf;
// forwards its operands as the parent op's results (ReturnLike).
def TF_YieldOp : TF_Op<"Yield",
      [NoSideEffect, ReturnLike, Terminator,
       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
  let summary = "Yield operation";

  let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., if and while). The operation
takes a variable number of operands and produces no results. The number and
types of inputs must match the signature of the operation that contains the
region.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);
}
|
|
// Region-based conditional: then/else bodies are attached single-block
// regions terminated by tf.Yield. See TF_IfOp for the functional form.
def TF_IfRegionOp : TF_Op<"IfRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
"output = cond ? then_branch output : else_branch output"

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
then_branch: A region that computes the outputs of the op if cond = true.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the else_branch
else_branch: A region that computes the outputs of the op if cond = false.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the then_branch
  }];

  let arguments = (ins
    // Unlike tf.If's cond, this is constrained to a 0-D boolean tensor.
    0DTensorOf<[I1]>:$cond,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless,

    // NOTE(review): optional names — presumably the original then/else
    // function names preserved when converting from functional tf.If;
    // confirm against the functional-to-region conversion pass.
    OptionalAttr<StrAttr>:$_then_func_name,
    OptionalAttr<StrAttr>:$_else_func_name
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);

  let verifier = [{
    return Verify(*this);
  }];

  let builders = [
    // Generic builder taking a region count; IfRegion always has exactly the
    // two regions declared above, so the count is only asserted.
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
      "llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
      "unsigned":$numRegions),
    [{
      assert(numRegions == 2u && "mismatched number of regions");
      build($_builder, $_state, resultTypes, operands, attributes);
    }]>];
}
|
|
// Direct call to a same-module function, exported as a GraphDef node whose
// op name is the function name. Implements CallOpInterface.
def TF_LegacyCallOp : TF_Op<"LegacyCall",
      [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
The LegacyCall operation represents a direct call to a function that is
within the same symbol scope as the call and is mapped to a GraphDef node
with the function name as the op name. Unlike a PartitionedCall which
represents asynchronously executing a function across multiple devices, a
LegacyCall ignores specification for ops in the attached function and
instead executes it on the device assigned to this op.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let extraClassDeclaration = [{
    // CallOpInterface: the operands passed to the callee.
    operand_range getArgOperands() { return args(); }

    // CallOpInterface: the callee, as the `f` symbol reference attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Resolves the callee function from the nearest symbol table.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];
}
|
|
// V1 tf.Example parser. Both operands and results are segmented variadics,
// hence the AttrSizedOperandSegments/AttrSizedResultSegments traits and the
// explicit *_segment_sizes attributes below.
def TF_ParseExampleOp : TF_Op<"ParseExample",
                              [NoSideEffect,
                               AttrSizedResultSegments,
                               AttrSizedOperandSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    Variadic<TF_StrTensor>:$sparse_keys,
    Variadic<TF_StrTensor>:$dense_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    TF_ShapeAttrArray:$dense_shapes,
    // Sizes of the variadic result/operand groups declared above.
    I32ElementsAttr:$result_segment_sizes,
    I32ElementsAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,
    Variadic<TF_Int64Tensor>:$sparse_shapes,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values
  );

  // Derived attributes: sizes/types recovered from the segmented operands
  // and results on export.
  TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
  TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

  // No custom verifier beyond the trait-generated checks.
  let verifier = ?;
}
|
|
// V2 tf.Example parser. Only the results are segmented variadics here; keys
// are single string tensors rather than variadic operands as in V1.
def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
                                [NoSideEffect,
                                 AttrSizedResultSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    TF_StrTensor:$sparse_keys,
    TF_StrTensor:$dense_keys,
    TF_StrTensor:$ragged_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
    TF_ShapeAttrArray:$dense_shapes,
    // Sizes of the variadic result groups declared below.
    I32ElementsAttr:$result_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,
    Variadic<TF_Int64Tensor>:$sparse_shapes,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values,
    Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits
  );

  // Derived attributes: types recovered from operands/results on export.
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
  TF_DerivedResultTypeListAttr ragged_value_types =
    TF_DerivedResultTypeListAttr<4>;
  TF_DerivedResultTypeListAttr ragged_split_types =
    TF_DerivedResultTypeListAttr<5>;

  let verifier = [{
    return Verify(*this);
  }];
}
|
|
// Stateless partitioned function call; see TF_StatefulPartitionedCallOp for
// the stateful variant (identical shape, minus NoSideEffect).
def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
      [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    SymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes, reconstructed from operand/result types on export.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // CallOpInterface: the operands passed to the callee.
    operand_range getArgOperands() { return args(); }

    // CallOpInterface: the callee, as the `f` symbol reference attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Resolves the callee function from the nearest symbol table.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];

  // Shared verifier with StatefulPartitionedCall.
  let verifier = [{ return VerifyPartitionedCall(*this); }];
}
|
|
// Placeholder: takes no operands; the fed value appears as the result.
def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
Inserts a placeholder for a tensor that will be always fed.
  }];

  // Intentionally empty: placeholders have no inputs.
  let arguments = (ins
  );

  let results = (outs
    TF_Tensor:$output
  );

  // dtype is derived from the result type rather than stored.
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
|
|
// Placeholder with a fallback: forwards `input` when the output is not fed.
def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
A placeholder op that passes through input when its output is not fed.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  // dtype and shape are derived from the result type rather than stored.
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
  DerivedAttr shape = TF_DerivedResultShapeAttr;
}
|
|
// Stateful partitioned function call; identical to TF_PartitionedCallOp
// except it omits NoSideEffect and requires a flat (unqualified) callee.
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
      [CallOpInterface]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes, reconstructed from operand/result types on export.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // CallOpInterface: the operands passed to the callee.
    operand_range getArgOperands() { return args(); }

    // CallOpInterface: the callee, as the `f` symbol reference attribute.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Resolves the callee function from the nearest symbol table.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];

  // Shared verifier with PartitionedCall.
  let verifier = [{ return VerifyPartitionedCall(*this); }];
}
|
|
// Functional while loop: cond/body are functions referenced by symbol.
// See TF_WhileRegionOp for the region-based form.
def TF_WhileOp : TF_Op<"While", []> {
  let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
  }];

  let description = [{
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
    a scalar of non-boolean, the scalar is converted to a boolean
    according to the following rule: if the scalar is a numerical
    value, non-zero means True and zero means False; if the scalar is
    a string, non-empty means True and empty means False. If the
    tensor is not a scalar, non-emptiness means True and False
    otherwise.
body: A function that takes a list of tensors and returns another
    list of tensors. Both lists have the same types as specified
    by T.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$cond,
    FlatSymbolRefAttr:$body,
    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless,

    // NOTE(review): presumably relaxes the requirement that operand and
    // result types match exactly, allowing per-iteration shape changes —
    // confirm against the While verifier in tf_ops.cc.
    UnitAttr:$shape_invariant
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes, reconstructed from operand/result types on export.
  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Resolves the loop-condition function from the nearest symbol table.
    FuncOp cond_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, cond());
    }

    // Resolves the loop-body function from the nearest symbol table.
    FuncOp body_function() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, body());
    }
  }];
}
|
|
// Region-based while loop: cond/body are attached single-block regions
// terminated by tf.Yield. Implements LoopLikeOpInterface.
def TF_WhileRegionOp : TF_Op<"WhileRegion",
      [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
       SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "while operation";
  let description = [{
The tf.WhileRegion op represents a while loop using 2 regions and a set of
iteration variables. The iteration variables maintained by this Op have the
same types as the inputs. The Op executes a while loop described by the
following pseudo code:

```
func WhileRegionOp(inputs) {
  iteration_vars = inputs;
  while (cond(iteration_vars)) {
    iteration_vars = body(iteration_vars);
  }
  return iteration_vars;
}
```

`cond` is the condition region and `body` is the body region. Both these
regions accept the current value of the iteration variables as inputs. The
condition region returns a tensor<i1> which, if false, will exit the loop.
The body region computes new values of the iteration variables. The iteration
variables are initialized to the Op input, and the results of the
tf.WhileRegion op are the final values of the iteration variables.

This implies that the operand and result types for tf.WhileRegion should be
the same. Note that the condition and body regions can implicitly capture
loop invariant values directly. In canonical form, iteration variables that
pass through the loop body unmodified are converted to implicitly captured
references to their values outside the loop.
  }];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,

    // NOTE(review): presumably marks the op as free of side effects so it can
    // be optimized more aggressively — confirm against the importer.
    BoolAttr:$is_stateless,

    // NOTE(review): presumably relaxes the requirement that operand and
    // result types match exactly, allowing per-iteration shape changes —
    // confirm against the WhileRegion verifier in tf_ops.cc.
    UnitAttr:$shape_invariant
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let verifier = [{ return Verify(*this); }];
}
|
|
// Creates a tensor list pre-sized to `num_elements` empty elements. Result
// handle, verifier and accessors come from the TF_TensorListInitOp base.
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
  let summary = "List of the given size with empty elements.";

  let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
  }];

  let arguments = (ins
    // Operand 0: feeds the base class's derived shape_type attribute.
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$num_elements
  );
}
|
|
| |
| |
| |
| |
// Attribute-only op (no operands, no results) carrying the replication
// metadata shared by the ops of one tpu.replicate() subgraph.
def TF_TPUReplicateMetadataOp : TF_Op<"TPUReplicateMetadata", []> {
  let summary = [{
Metadata indicating how the TPU computation should be replicated.
  }];

  let description = [{
This operation holds the metadata common to operations of a `tpu.replicate()` computation subgraph.
  }];

  let arguments = (ins
    Confined<I64Attr, [IntMinValue<0>]>:$num_replicas,
    DefaultValuedAttr<I64Attr, "1">:$num_cores_per_replica,
    StrAttr:$topology,
    DefaultValuedAttr<BoolAttr, "true">:$use_tpu,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$device_assignment,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$computation_shape,
    DefaultValuedAttr<StrArrayAttr, "{}">:$host_compute_core,
    DefaultValuedAttr<StrArrayAttr, "{}">:$padding_map,
    DefaultValuedAttr<StrAttr, "STEP_MARK_AT_ENTRY">:$step_marker_location,
    DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement,
    DefaultValuedAttr<BoolAttr, "false">:$use_spmd_for_xla_partitioning
  );

  let results = (outs);
}
|
|
// Produces a resource handle for a named variable; the result carries a
// TF_VariableAlloc resource effect.
def TF_VarHandleOp : TF_Op<"VarHandleOp", []> {
  let summary = "Creates a handle to a Variable resource from its name.";

  let description = [{
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype and shape: attributes representing the data type and shape held in the
variable.

Example:
  resource_variable_ops.var_handle_op(
      dtype=dtypes.int32, shape=[8, 16], container="foo", shared_name="bar")
returns a handle for a variable with name "bar" in container "foo", and the
variable holds a tensor of shape [8, 16] and dtype int32.
  }];

  let arguments = (ins
    DefaultValuedAttr<StrAttr, "">:$container,
    DefaultValuedAttr<StrAttr, "">:$shared_name
  );

  let results = (outs
    // Resource result annotated with a variable-allocation side effect.
    Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
  );
}
|
|
// Enqueues batch tensors to TPUEmbedding; carries the TPU-embedding side
// effect so it is not dead-code eliminated despite having no results.
def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [TF_TPUEmbeddingSideEffect]> {
  let summary = [{
An op that enqueues a list of input batch tensors to TPUEmbedding.
  }];

  let arguments = (ins
    Variadic<TF_StrTensor>:$batch,
    TF_StrTensor:$mode_override,

    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
    DefaultValuedAttr<StrArrayAttr, "{}">:$combiners
  );

  let results = (outs);

  // Batch count derived from the size of the variadic operand group 0.
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
|
|
| |
| |
// Sends embedding-table gradients to TPUEmbedding. Two variadic operand
// groups, hence AttrSizedOperandSegments; no results, so the side-effect
// trait keeps the op alive.
def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
  let summary = "Performs gradient updates of embedding tables.";

  let description = [{
inputs: A TensorList of gradients with which to update embedding tables.
    This argument has the same length and shapes as the return value of
    RecvTPUEmbeddingActivations, but contains gradients of the model's loss
    with respect to the embedding activations. The embedding tables are updated
    from these gradients via the optimizer specified in the TPU embedding
    configuration given to tpu.initialize_system.
learning_rates: A TensorList of float32 scalars, one for each dynamic learning
    rate tag: see the comments in
    Multiple tables can share the same dynamic learning rate tag as specified
    in the configuration. If the learning rates for all tables are constant,
    this list should be empty.
config: Serialized TPUEmbeddingConfiguration proto.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,
    Variadic<TF_Tensor>:$learning_rates,
    StrAttr:$config
  );

  // Sizes derived from the two variadic operand groups.
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
  TF_DerivedOperandSizeAttr NN = TF_DerivedOperandSizeAttr<1>;
}
|
|
| |
| |
// Internal (underscore-prefixed) variant of SendTPUEmbeddingGradients that
// additionally threads deduplication data from
// _RecvTPUEmbeddingDeduplicationData.
def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
  let summary = "Performs gradient updates of embedding tables.";

  let description = [{
The gradients argument is a TensorList having the same length and shapes as the
return value of _RecvTPUEmbeddingActivations, but contains gradients of the
model's loss with respect to the embedding activations. The embedding tables are
updated from these gradients via the optimizer specified in the
TPUEmbeddingConfiguration proto given to tpu.initialize_system.

gradients: A TensorList of gradients with which to update embedding tables.
learning_rates: A TensorList of learning rates used for updating the embedding
    tables via the optimizer. The length of the TensorList must be equal to the
    number of dynamic learning rate tags specified in the
    TPUEmbeddingConfiguration proto.
deduplication_data: A Tensor with type=DT_VARIANT containing the deduplication
    data. The tensor is an XLA nested tuple containing N elements. Each
    element of the nested tuple is a tuple of rank 1 tensors. Each tensor either
    contains indices (DT_INT32) for embedding lookup or weights (DT_FLOAT) to
    apply to the output of the embedding lookup operation.
config: Serialized TPUEmbeddingConfiguration proto.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$gradients,
    Variadic<TF_Tensor>:$learning_rates,
    TF_VariantTensor:$deduplication_data,
    StrAttr:$config
  );

  // Sizes derived from the two variadic operand groups.
  TF_DerivedOperandSizeAttr NumTables = TF_DerivedOperandSizeAttr<0>;
  TF_DerivedOperandSizeAttr NumLearningRateTags = TF_DerivedOperandSizeAttr<1>;
}
|
|
| |
// Internal op producing the variant deduplication data consumed by
// _SendTPUEmbeddingGradients.
def TF__RecvTPUEmbeddingDeduplicationDataOp : TF_Op<"_RecvTPUEmbeddingDeduplicationData", []> {
  let summary = [{
Receives deduplication data (indices and weights).
  }];

  let description = [{
The deduplication data is a Tensor with type=DT_VARIANT. The tensor itself is an
XLA nested tuple containing N elements. Each element of the nested tuple is a
tuple of rank 1 tensors. Each tensor either contains indices (DT_INT32) for
embedding lookup or weights (DT_FLOAT) to apply to the output of the embedding
lookup operation.
  }];

  let arguments = (ins
    // Serialized TPUEmbeddingConfiguration proto.
    StrAttr:$config
  );

  let results = (outs
    TF_VariantTensor:$output
  );
}
|
|
// Annotates its input with an XLA sharding; data passes through unchanged.
def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
  }];

  let arguments = (ins
    TF_Tensor:$input,

    DefaultValuedAttr<StrAttr, "">:$sharding,
    // NOTE(review): presumably the serialized xla::OpSharding proto attached
    // by the bridge — confirm against the sharding propagation passes.
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  // Element type attribute derived from operand 0.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Dequeues a tuple of tensors from TPU infeed; result shapes/types are
// carried as derived attributes.
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  let summary = "Fetches multiple values from infeed as an XLA tuple.";

  let arguments = (ins
    // NOTE(review): presumably the serialized xla::OpSharding proto attached
    // by the bridge — confirm against the sharding propagation passes.
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  // Shapes and dtypes derived from the variadic result group.
  TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>;
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}
|
|
// Renders tensor summaries into a string template.
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";

  let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    // Template string; each `placeholder` occurrence is replaced by the
    // summary of the corresponding input tensor.
    DefaultValuedAttr<StrAttr, "%s">:$strtemplate,
    DefaultValuedAttr<StrAttr, "%s">:$placeholder,
    // Max entries printed per tensor dimension in the summary.
    DefaultValuedAttr<I64Attr, "3">:$summarize
  );

  let results = (outs
    TF_StrTensor:$output
  );

  // Input types derived from the variadic operand group.
  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
|
|
| |
| |
| |
|
|
// Dataset op: batches `batch_size` elements of the input dataset.
def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
  let summary = [{
Creates a dataset that batches `batch_size` elements from `input_dataset`.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    TF_Int64Tensor:$batch_size,
    // Whether a final partial batch is dropped.
    TF_BoolTensor:$drop_remainder,

    DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
    // Output signature; both lists must be non-empty.
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
  );

  let results = (outs
    TF_VariantTensor:$handle
  );
}
|
|
// Dataset op: maps function `f` over the elements of the input dataset.
def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> {
  let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    // Captured arguments forwarded to `f` alongside each element.
    Variadic<TF_Tensor>:$other_arguments,

    SymbolRefAttr:$f,
    // Output signature; both lists must be non-empty.
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  // Types of the captured arguments, derived from operand group 1.
  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
|
|
// Dataset op: fused map + batch, running up to
// batch_size * num_parallel_batches invocations of `f` in parallel.
def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> {
  let summary = "Creates a dataset that fuses mapping with batching.";

  let description = [{
Creates a dataset that applies `f` to the outputs of `input_dataset` and then
batches `batch_size` of them.

Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
to `batch_size * num_parallel_batches` copies of `f` in parallel.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    // Captured arguments forwarded to `f` alongside each element.
    Variadic<TF_Tensor>:$other_arguments,
    TF_Int64Tensor:$batch_size,
    TF_Int64Tensor:$num_parallel_calls,
    // Whether a final partial batch is dropped.
    TF_BoolTensor:$drop_remainder,

    SymbolRefAttr:$f,
    // Output signature; both lists must be non-empty.
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  // Types of the captured arguments, derived from operand group 1.
  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
|
|
// Dataset-construction op: parallel variant of MapDataset. Note that
// `num_parallel_calls` is an i32 tensor here (i64 in MapAndBatchDataset).
def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
  let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
  }];

  let description = [{
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes
up to `num_parallel_calls` copies of `f` in parallel.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    // Extra tensors captured by `f` beyond the dataset element itself.
    Variadic<TF_Tensor>:$other_arguments,
    TF_Int32Tensor:$num_parallel_calls,

    // Symbol reference to the mapped function.
    SymbolRefAttr:$f,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
    // `sloppy` allows out-of-order element production - TODO confirm against
    // the TensorFlow op registry; semantics not visible here.
    DefaultValuedAttr<BoolAttr, "false">:$sloppy,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  // Derived attr: element types of the variadic operand #1
  // (`other_arguments`).
  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
|
|
// Dataset-construction op: builds a dataset from in-graph tensors. No
// NoSideEffect trait (unlike the dataset ops above).
def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
  let summary = [{
Creates a dataset that emits each dim-0 slice of `components` once.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$components,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  // Derived attr: `Toutput_types` comes from the types of the variadic
  // operand #0 (`components`) rather than from an explicit attribute.
  TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
}
|
|
// Converts an arbitrary tensor into a rank-0 i1 predicate, mirroring the
// truthiness rules used by If/While conditions.
def TF_ToBoolOp : TF_Op<"ToBool", [NoSideEffect]> {
  let summary = "Converts a tensor to a scalar predicate.";

  let description = [{
Converts a tensor to a scalar predicate with the following rules:

- For 0D tensors, truthiness is determined by comparing against a "zero"
  value. For numerical types it is the obvious zero. For strings it is the
  empty string.

- For >0D tensors, truthiness is determined by looking at the number of
  elements. If has zero elements, then the result is false. Otherwise the
  result is true.

This matches the behavior of If and While for determining if a tensor counts
as true/false for a branch condition.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    // Result is always a rank-0 (scalar) boolean tensor.
    0DTensorOf<[I1]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  // Convenience builder: infers the scalar i1 result type from the builder,
  // so callers only supply the input value.
  let builders = [
    OpBuilder<(ins "Value":$value),
    [{
      build($_builder, $_state, RankedTensorType::get({}, $_builder.getI1Type()),
            value);
    }]>];
}
|
|
// Elementwise unary math op; SameOperandsAndResultType pins result type to
// the input type.
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i0e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically stabler than `bessel_i0(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Elementwise unary math op; SameOperandsAndResultType pins result type to
// the input type.
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i1e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically stabler than `bessel_i1(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Call op targeting a TPU device; implements CallOpInterface so generic
// call-graph analyses can see through it.
def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    // Scalar-ish i32 tensor selecting the TPU device ordinal - rank is not
    // constrained here; TODO confirm expected rank against the op registry.
    TF_Int32Tensor:$device_ordinal,

    // Symbol reference to the callee function.
    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attrs: input/output dtypes come from operand/result types.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // CallOpInterface: the operands forwarded to the callee.
    operand_range getArgOperands() { return args(); }

    // CallOpInterface: the callee as a callable (symbol ref attr `f`).
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Resolves `f` to its FuncOp via the nearest symbol table; may return
    // null if the symbol is not found.
    FuncOp func() {
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, f());
    }
  }];

  // Shares the PartitionedCall verifier (defined in the C++ sources).
  let verifier = [{ return VerifyPartitionedCall(*this); }];
}
|
|
// Batching op: gathers concurrent invocations' inputs, batches them along
// dim 0, and runs `f` once per batch. AttrSizedOperandSegments is required
// because both operand groups are variadic; `operand_segment_sizes` records
// their split.
def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
  let summary = [{
Batches all the inputs tensors to the computation done by the function.
  }];

  let description = [{
So, for example, in the following code

```python

# This input will be captured.
y = tf.placeholder_with_default(1.0, shape=[])

@tf.Defun(tf.float32)
def computation(a):
  return tf.matmul(a, a) + y

b = gen_batch_ops.batch_function(
        f=computation,
        in_tensors=[a],
        captured_tensors=computation.captured_inputs,
        Tout=[o.type for o in computation.definition.signature.output_arg],
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        allowed_batch_sizes=[3, 10],
        batching_queue="")
```

If more than one session.run call is simultaneously trying to compute `b`
the values of `a` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation.

Assumes that all arguments of the function are Tensors which will be batched
along their first dimension.

Arguments that are captured, are not batched. The session.run call which does
the concatenation, will use the values of the captured tensors available to it.
Therefore, typical uses of captured tensors should involve values which remain
unchanged across session.run calls. Inference is a good example of this.

SparseTensor is not supported. The return value of the decorated function
must be a Tensor or a list/tuple of Tensors.
  }];

  let arguments = (ins
    // Tensors batched along dim 0 before calling `f`.
    Variadic<TF_Tensor>:$in_tensors,
    // Captured tensors passed through to `f` unbatched.
    Variadic<TF_Tensor>:$captured_tensors,

    SymbolRefAttr:$f,
    I64Attr:$num_batch_threads,
    I64Attr:$max_batch_size,
    I64Attr:$batch_timeout_micros,
    DefaultValuedAttr<I64Attr, "10">:$max_enqueued_batches,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$allowed_batch_sizes,
    StrAttr:$container,
    StrAttr:$shared_name,
    StrAttr:$batching_queue,
    DefaultValuedAttr<BoolAttr, "false">:$enable_large_batch_splitting,
    // Sizes of the two variadic operand groups, required by
    // AttrSizedOperandSegments.
    I32ElementsAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Tensor>:$out_tensors
  );

  // Derived dtype attrs for the two operand groups and the results.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
|
|
// Elementwise binary add with numpy-style broadcasting; the builder comes
// from WithBroadcastableBinOpBuilder.
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
                 WithBroadcastableBinOpBuilder {
  let summary = "Returns x + y element-wise.";

  let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  }];

  let arguments = (ins
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$x,
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$y
  );

  let results = (outs
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Elementwise binary division returning 0 where the denominator is 0;
// supports numpy-style broadcasting.
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
                    WithBroadcastableBinOpBuilder {
  let summary = "Returns 0 if the denominator is zero.";

  let description = [{
*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  }];

  let arguments = (ins
    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$x,
    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$y
  );

  let results = (outs
    TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Elementwise binary max with numpy-style broadcasting.
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
                   WithBroadcastableBinOpBuilder {
  let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";

  let description = [{
*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  }];

  let arguments = (ins
    TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$x,
    TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$y
  );

  let results = (outs
    TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Elementwise real-valued division with numpy-style broadcasting.
def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
                   WithBroadcastableBinOpBuilder {
  let summary = "Returns x / y element-wise for real types.";

  let description = [{
If `x` and `y` are reals, this will return the floating-point division.

*NOTE*: `Div` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  }];

  let arguments = (ins
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$x,
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$y
  );

  let results = (outs
    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Legacy elementwise add (superseded by AddV2 for numeric types; Add also
// accepts strings per its element-type constraint below). Supports
// numpy-style broadcasting.
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
               WithBroadcastableBinOpBuilder {
  let summary = "Returns x + y element-wise.";

  let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)

Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor.

Both input and output have a range `(-inf, inf)`.
  }];

  let arguments = (ins
    TensorOf<[TF_NumberNotQuantizedOrStr]>:$x,
    TensorOf<[TF_NumberNotQuantizedOrStr]>:$y
  );

  let results = (outs
    TensorOf<[TF_NumberNotQuantizedOrStr]>:$z
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// Stateful RNG op: reads and advances the generator state held in
// `resource` (declared via TF_VariableRead/TF_VariableWrite effects).
def TF_StatefulStandardNormalV2Op : TF_Op<"StatefulStandardNormalV2", []> {
  let summary = "Outputs random values from a normal distribution.";

  let description = [{
The generated values will have mean 0 and standard deviation 1.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TF_FloatTensor:$output
  );

  // Derived attrs: shape dtype from operand #2, output dtype from result #0.
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
|
|
// Stateful RNG op: reads and advances the generator state held in
// `resource`.
def TF_StatefulTruncatedNormalOp : TF_Op<"StatefulTruncatedNormal", []> {
  let summary = "Outputs random values from a truncated normal distribution.";

  let description = [{
The generated values follow a normal distribution with mean 0 and standard
deviation 1, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TF_FloatTensor:$output
  );

  // Derived attrs: shape dtype from operand #2, output dtype from result #0.
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
|
|
// Stateful RNG op: reads and advances the generator state held in
// `resource`.
def TF_StatefulUniformOp : TF_Op<"StatefulUniform", []> {
  let summary = "Outputs random values from a uniform distribution.";

  let description = [{
The generated values follow a uniform distribution in the range `[0, 1)`. The
lower bound 0 is included in the range, while the upper bound 1 is excluded.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TF_FloatTensor:$output
  );

  // Derived attrs: shape dtype from operand #2, output dtype from result #0.
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
|
|
// Stateful RNG op producing integers over the full range of the output
// dtype; reads and advances the generator state held in `resource`.
def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  // Derived attrs: shape dtype from operand #2, output dtype from result #0.
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
|
|
| |
| |
| |
// Stateful RNG op producing integers in [minval, maxval). Unlike its
// siblings above, `dtype` is derived from operand #3 (`minval`), not from
// the result.
def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.

The random integers are slightly biased unless `maxval - minval` is an exact
power of two. The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>;
}
|
|
// Summary-writer lifecycle op: frees the writer resource (TF_SummaryFree
// effect); produces no results.
def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> {
  let summary = "Flushes and closes the summary writer.";

  let description = [{
Also removes it from the resource manager. To reopen, use another
CreateSummaryFileWriter op.

writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryFree]>:$writer
  );

  let results = (outs);
}
|
|
| |
// Summary-writer lifecycle op: initializes `writer` to write to a SQLite
// database (TF_SummaryWrite effect); produces no results.
def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> {
  let summary = "Creates summary database writer accessible by given resource handle.";

  let description = [{
This can be used to write tensors from the execution graph directly
to a database. Only SQLite is supported right now. This function
will create the schema if it doesn't exist. Entries in the Users,
Experiments, and Runs tables will be created automatically if they
don't already exist.

writer: Handle to SummaryWriter resource to overwrite.
db_uri: For example "file:/tmp/foo.sqlite".
experiment_name: Can't contain ASCII control characters or <>. Case
  sensitive. If empty, then the Run will not be associated with any
  Experiment.
run_name: Can't contain ASCII control characters or <>. Case sensitive.
  If empty, then each Tag will not be associated with any Run.
user_name: Must be valid as both a DNS label and Linux username. If
  empty, then the Experiment will not be associated with any User.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$db_uri,
    TF_StrTensor:$experiment_name,
    TF_StrTensor:$run_name,
    TF_StrTensor:$user_name
  );

  let results = (outs);
}
|
|
| |
// Summary-writer lifecycle op: initializes `writer` to write event files
// under `logdir` (TF_SummaryWrite effect); produces no results.
def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> {
  let summary = "Creates a summary file writer accessible by the given resource handle.";

  let description = [{
writer: A handle to the summary writer resource
logdir: Directory where the event file will be written.
max_queue: Size of the queue of pending events and summaries.
flush_millis: How often, in milliseconds, to flush the pending events and
  summaries to disk.
filename_suffix: Every event file's name is suffixed with this suffix.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$logdir,
    TF_Int32Tensor:$max_queue,
    TF_Int32Tensor:$flush_millis,
    TF_StrTensor:$filename_suffix
  );

  let results = (outs);
}
|
|
// Summary op: flushes buffered events of `writer` (TF_SummaryWrite effect);
// produces no results.
def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> {
  let summary = "Flushes the writer's unwritten events.";

  let description = [{
writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer
  );

  let results = (outs);
}
|
|
// Summary op: writes a serialized tf.Event proto through `writer`
// (TF_SummaryWrite effect); produces no results.
def TF_ImportEventOp : TF_Op<"ImportEvent", []> {
  let summary = "Outputs a `tf.Event` protocol buffer.";

  let description = [{
When CreateSummaryDbWriter is being used, this op can be useful for
importing data from event logs.

writer: A handle to a summary writer.
event: A string containing a binary-encoded tf.Event proto.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$event
  );

  let results = (outs);
}
|
|
// Summary-writer lifecycle op: allocates the writer resource
// (TF_SummaryAlloc effect on the result). Takes only attrs, no operands.
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> {
  let summary = "Returns a handle to be used to access a summary writer.";

  let description = [{
The summary writer is an in-graph resource which can be used by ops to write
summaries to event files.

writer: the summary writer resource. Scalar handle.
  }];

  let arguments = (ins
    StrAttr:$shared_name,
    StrAttr:$container
  );

  let results = (outs
    Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
  );
}
|
|
// Summary op: writes an audio summary through `writer` (TF_SummaryWrite
// effect); produces no results.
// NOTE(review): the description's first paragraph allows 2-D or 3-D
// `tensor`, while the per-argument note says only 2-D - inherited upstream
// inconsistency; confirm against the kernel before relying on either.
def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with audio.";

  let description = [{
The summary has up to `max_outputs` summary values containing audio. The
audio is built from `tensor` which must be 3-D with shape `[batch_size,
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.

The `tag` argument is a scalar `Tensor` of type `string`. It is used to
build the `tag` of the summary values:

*  If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
*  If `max_outputs` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 2-D of shape `[batch_size, frames]`.
sample_rate: The sample rate of the signal in hertz.
max_outputs: Max number of batch elements to generate audio for.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_Float32Tensor:$tensor,
    TF_Float32Tensor:$sample_rate,

    // Attr must be >= 1; defaults to 3.
    Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
  );

  let results = (outs);
}
|
|
// Summary op: writes a serialized GraphDef through `writer`
// (TF_SummaryWrite effect); produces no results.
def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> {
  let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`.";

  let description = [{
writer: Handle of `SummaryWriter`.
step: The step to write the summary for.
tensor: A scalar string of the serialized tf.GraphDef proto.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}
|
|
// Summary op: writes a histogram summary through `writer` (TF_SummaryWrite
// effect); produces no results.
def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> {
  let summary = "Writes a histogram summary.";

  let description = [{
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.

This op reports an `InvalidArgument` error if any value is not finite.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Tag to use for the `Summary.Value`.
values: Any shape. Values to use to build the histogram.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$values
  );

  let results = (outs);

  // Derived attr: dtype of operand #3 (`values`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
|
|
// Summary op: writes an image summary through `writer` (TF_SummaryWrite
// effect); produces no results.
def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with images.";

  let description = [{
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 4-D with shape `[batch_size,
height, width, channels]` and where `channels` can be:

*  1: `tensor` is interpreted as Grayscale.
*  3: `tensor` is interpreted as RGB.
*  4: `tensor` is interpreted as RGBA.

The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`.  `uint8` values are unchanged.  The op uses two different
normalization algorithms:

*  If the input values are all positive, they are rescaled so the largest one
   is 255.

*  If any input value is negative, the values are shifted so input value 0.0
   is at 127.  They are then rescaled so that either the smallest value is 0,
   or the largest one is 255.

The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
build the `tag` of the summary values:

*  If `max_images` is 1, the summary value tag is '*tag*/image'.
*  If `max_images` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.

The `bad_color` argument is the color to use in the generated images for
non-finite input values.  It is a `uint8` 1-D tensor of length `channels`.
Each element must be in the range `[0, 255]` (It represents the value of a
pixel in the output image).  Non-finite values in the input tensor are
replaced by this tensor in the output image.  The default value is the color
red.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 4-D of shape `[batch_size, height, width, channels]` where
  `channels` is 1, 3, or 4.
max_images: Max number of batch elements to generate images for.
bad_color: Color to use for pixels with non-finite values.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
    TF_Uint8Tensor:$bad_color,

    // Attr must be >= 1; defaults to 3.
    Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
  );

  let results = (outs);

  // Derived attr: dtype of operand #3 (`tensor`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
|
|
// Summary op: writes pre-serialized Summary protos through `writer`
// (TF_SummaryWrite effect); produces no results.
def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor holding one or more serialized `Summary` protobufs to write.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}
|
|
// Summary op: writes a scalar-valued summary through `writer`
// (TF_SummaryWrite effect); produces no results.
def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with scalar values.";

  let description = [{
The input `tag` and `value` must be scalars.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Tag for the summary.
value: Value for the summary.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$value
  );

  let results = (outs);

  // Derived attr: dtype of operand #3 (`value`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
|
|
// Summary op: writes an arbitrary tensor summary through `writer`
// (TF_SummaryWrite effect); produces no results.
def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> {
  let summary = "Outputs a `Summary` protocol buffer with a tensor.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor to serialize.
tag: The summary's tag.
summary_metadata: Serialized SummaryMetadata protocol buffer containing
 plugin-related metadata for this summary.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_Tensor:$tensor,
    TF_StrTensor:$tag,
    TF_StrTensor:$summary_metadata
  );

  let results = (outs);

  // Derived attr: dtype of operand #2 (`tensor`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
|
|
// Lookup-table op: populates `table_handle` from the elements of `dataset`
// (TF_LookupTableWrite effect); produces no results. Upstream registers no
// summary/description for this op.
def TF_InitializeTableFromDatasetOp : TF_Op<"InitializeTableFromDataset", []> {
  let summary = "";

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_LookupTableWrite]>:$table_handle,
    TF_VariantTensor:$dataset
  );

  let results = (outs);
}
|
|
| |
// Lookup-table op: populates `table_handle` from a text file
// (TF_LookupTableWrite effect); produces no results.
def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> {
  let summary = "Initializes a table from a text file.";

  let description = [{
It inserts one key-value pair into the table for each line of the file.
The key and value is extracted from the whole line content, elements from the
split line based on `delimiter` or the line number (starting from zero).
Where to extract the key and value from a line is specified by `key_index` and
`value_index`.

- A value of -1 means use the line number(starting from zero), expects `int64`.
- A value of -2 means use the whole line content, expects `string`.
- A value >= 0 means use the index (starting at zero) of the split line based
  on `delimiter`.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_LookupTableWrite]>:$table_handle,
    TF_StrTensor:$filename,

    // -2 and -1 are sentinel values (whole line / line number); see the
    // description above.
    Confined<I64Attr, [IntMinValue<-2>]>:$key_index,
    Confined<I64Attr, [IntMinValue<-2>]>:$value_index,
    // -1 means "unlimited vocabulary size".
    Confined<DefaultValuedAttr<I64Attr, "-1">, [IntMinValue<-1>]>:$vocab_size,
    DefaultValuedAttr<StrAttr, "\t">:$delimiter
  );

  let results = (outs);
}
|
|
| |
// Dataset-construction op with explicit memory-cache resource effects
// (read+write on `cache`). Upstream registers no summary for this op.
def TF_CacheDatasetV2Op : TF_Op<"CacheDatasetV2", []> {
  let summary = "";

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    TF_StrTensor:$filename,
    Arg<TF_ResourceTensor, "", [TF_DatasetMemoryCacheRead, TF_DatasetMemoryCacheWrite]>:$cache,

    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
  );

  let results = (outs
    TF_VariantTensor:$handle
  );
}
|
|
// Internal (leading-underscore) compiler op: a placeholder later rewrite
// passes must replace with a concrete device-ordinal constant. No operands.
def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", []> {
  let summary = [{
Placeholder device ordinal that represents device ordinal of a replicated op.
  }];

  let description = [{
This op can be used when certain rewrite passes materialize ops that require a
device ordinal of a replicated op but replication logic has been abstracted away
using tf_device.replicate op. Subsequent rewrite passes must replace this op with
a constant output that represents the correct device ordinal of the replicated
operations inside a TPU host.
  }];

  let arguments = (ins);

  let results = (outs
    TF_Int64Tensor:$device_ordinal
  );
}
|
|
// Legacy reference-variable assignment (pre-resource-variable API). No
// NoSideEffect trait: the op mutates `ref` in place per its description.
def TF_AssignOp : TF_Op<"Assign", []> {
  let summary = "Update 'ref' by assigning 'value' to it.";

  let description = [{
This operation outputs "ref" after the assignment is done.
This makes it easier to chain operations that need to use the reset value.

This is a side-effecting operation because it will change the value of its
argument "ref" in addition to returning the results.
  }];

  let arguments = (ins
    TF_Tensor:$ref,
    TF_Tensor:$value,

    DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
    DefaultValuedAttr<BoolAttr, "true">:$use_locking
  );

  let results = (outs
    TF_Tensor:$output_ref
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
|
// TPU sharding op: groups N partitioned inputs into one logical value.
// Counterpart of TPUPartitionedOutput below.
def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> {
  let summary = [{
An op that groups a list of partitioned inputs together. This op
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    // Dimension along which the inputs are partitioned.
    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    // Optional serialized XLA sharding annotation.
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  // Derived attrs: dtype and count of the variadic operand #0.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
|
|
// TPU sharding op: splits one XLA-sharded tensor into `num_splits`
// partitioned outputs. Counterpart of TPUPartitionedInput above.
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
  let summary = [{
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
  }];

  let description = [{
outputs outside the XLA computation.
  }];

  let arguments = (ins
    TF_Tensor:$inputs,

    // Dimension along which the output is partitioned.
    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    // Optional serialized XLA sharding annotation.
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attrs: dtype from operand #0, split count from the variadic
  // result #0.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
}
|
|
| #endif |
|
|