Spaces:
Build error
Build error
| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| ==============================================================================*/ | |
| namespace tensorflow { | |
| namespace { | |
| // A helper class to make AttrSlice from initializer lists | |
| class Attrs { | |
| public: | |
| Attrs(const std::initializer_list< // NOLINT(runtime/explicit) | |
| std::pair<string, FunctionDefHelper::AttrValueWrapper>> | |
| attrs) { | |
| for (const auto& aval : attrs) { | |
| map_.insert({aval.first, aval.second.proto}); | |
| } | |
| } | |
| operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) | |
| private: | |
| AttrValueMap map_; | |
| }; | |
| typedef FunctionDefHelper FDH; | |
| Status GetOpSig(const string& op, const OpDef** sig) { | |
| return OpRegistry::Global()->LookUpOpDef(op, sig); | |
| } | |
// Test-only op "One": no inputs, a single output "y" whose type is the
// numeric attr T. Its registered doc describes y as a scalar holding 1.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.
y: A scalar in type T.
)doc");
| TEST(TFunc, SquarePlusOne) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "SquarePlusOne", | |
| // Inputs | |
| {"x: T"}, | |
| // Outputs | |
| {"y: T"}, | |
| // Attrs | |
| {"T: {float, double, int32, int64}"}, | |
| // Nodes | |
| {// a = Square<T>(x) | |
| {{"a"}, "Square", {"x"}, {{"T", "$T"}}}, | |
| // o = One<T>() | |
| // NOTE: We can also have a Cast<Tin, Tout>(x) instead. | |
| {{"o"}, "One", {}, {{"T", "$T"}}}, | |
| // y = Add<T>(a, o) | |
| {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}}, | |
| // Returns | |
| {{"y", "y:z:0"}}); | |
| const char* e = R"P( | |
| SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { | |
| a = Square[T=$T](x) | |
| o = One[T=$T]() | |
| y = Add[T=$T](a:y, o:y) | |
| return y = y:z:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| // Instantiate one with T=float | |
| InstantiationResult result; | |
| TF_ASSERT_OK( | |
| InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x:float) -> (y:float) { | |
| a = Square[T=float](x) | |
| o = One[T=float]() | |
| y = Add[T=float](a, o) | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
| TEST(TFunc, ControlDep) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "ControlDep", | |
| // Inputs | |
| {"x: int32"}, | |
| // Outputs | |
| {"y: int32"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = Identity<int32>(x) | |
| {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}}, | |
| // o = NoOp(^a) | |
| {{"o"}, "NoOp", {"^a"}, {}}, | |
| // y = Identity<int32>(a, ^o) | |
| {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}}, | |
| // Returns | |
| {{"y", "y:output:0"}}); | |
| const char* e = R"P( | |
| ControlDep(x:int32) -> (y:int32) { | |
| a = Identity[T=int32](x) | |
| o = NoOp() @ a | |
| y = Identity[T=int32](a:output:0) @ o | |
| return y = y:output:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| // Instantiate one with T=float | |
| InstantiationResult result; | |
| TF_ASSERT_OK( | |
| InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x:int32) -> (y:int32) { | |
| a = Identity[T=int32](x) | |
| o = NoOp() @ a | |
| y = Identity[T=int32](a) @ o | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
// Test-only op "HasDefaultType": no inputs, one output "out" of type T, where
// T defaults to DT_FLOAT when a node does not specify it.
REGISTER_OP("HasDefaultType")
    .Output("out: T")
    .Attr("T: {float, double, int32, int64} = DT_FLOAT");
| // This verifies that a function using an op before a type attr (with | |
| // a default) is added, still works. This is important for backwards | |
| // compatibility. | |
| TEST(TFunc, MissingTypeAttr) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "BackCompat", | |
| // Args | |
| {}, | |
| // Return values | |
| {"y: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// y = HasDefaultType(x), T missing, defaults to float | |
| {{"a"}, "HasDefaultType", {}, {}}}, | |
| // Returns | |
| {{"y", "a:out:0"}}); | |
| const char* e = R"P( | |
| BackCompat() -> (y:float) { | |
| a = HasDefaultType() | |
| return y = a:out:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); | |
| // Should get T=float from Op's default. | |
| const char* e2 = R"P( | |
| () -> (a:float) { | |
| a = HasDefaultType[T=float]() | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector()); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
| TEST(TFunc, NTimesT) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "NTimesT", | |
| // Inputs | |
| {"x: float", "y: float"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = AddN<N=2>(x, y) | |
| {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| const char* e = R"P( | |
| NTimesT(x:float, y:float) -> (z:float) { | |
| a = AddN[N=2, T=float](x, y) | |
| return z = a:sum:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x:float, y:float) -> (a:float) { | |
| a = AddN[N=2, T=float](x, y) | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
// NOTE: This is the simplest Map op: it applies a function f: T -> U to each
// of its N inputs. The "func" attr is not declared in the registration below
// (it is commented out), although AddSquared passes a "func" attr to Map nodes.
REGISTER_OP("Map")
    .Input("x: N * T")
    .Output("y: N * U")
    .Attr("T: type")
    .Attr("U: type")
    .Attr("N: int >= 1")
    // .Attr("func: func_name_with_attr")
    .Doc(R"doc(
Applies the 'func' on every input. I.e.,
y[i] = func<...>(x[i])
x: N tensors, each of type T;
y: N tensors, each of type U;
)doc");
| TEST(TFunc, AddSquared) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "AddSquared", | |
| // Args | |
| {"x: N*T"}, | |
| // Return values | |
| {"y: T"}, | |
| // Attrs | |
| {"N:int", "T:{float, double, int32, int64}"}, | |
| // Nodes | |
| {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x) | |
| {{"a"}, | |
| "Map", | |
| {"x"}, | |
| {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})}, | |
| {"T", "$T"}, | |
| {"U", "$T"}, | |
| {"N", "$N"}}}, | |
| // y = AddN<N=$N,T=$T>(a) | |
| {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}}, | |
| {{"y", "y:sum"}}); | |
| const char* e = R"P( | |
| AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { | |
| a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x) | |
| y = AddN[N=$N, T=$T](a:y) | |
| return y = y:sum | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| // Instantiate one with T=float | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}), | |
| GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x_0:float, x_1:float, x_2:float) -> (y:float) { | |
| a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2) | |
| y = AddN[N=3, T=float](a, a:1, a:2) | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
| TEST(TFunc, ControlDeps) { | |
| auto fdef = FDH::Define( | |
| // Name | |
| "ControlDeps", | |
| // Args | |
| {"x: float"}, | |
| // Return values | |
| {}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| { | |
| {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}}, | |
| {{"u"}, "NoOp", {}, {}, {"a"}}, | |
| {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}}, | |
| {{"v"}, "NoOp", {}, {}, {"b"}}, | |
| {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}}, | |
| }); | |
| const char* e = R"P( | |
| ControlDeps(x:float) -> () { | |
| a = One[T=float]() @ x | |
| u = NoOp() @ a | |
| b = One[T=float]() @ u | |
| v = NoOp() @ b | |
| c = One[T=float]() @ a, v | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x:float) -> () { | |
| a = One[T=float]() @ x | |
| u = NoOp() @ a | |
| b = One[T=float]() @ u | |
| v = NoOp() @ b | |
| c = One[T=float]() @ a, v | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
| TEST(TFunc, XTimesTwo) { | |
| auto expect = R"P( | |
| XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { | |
| two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() | |
| scale = Cast[DstT=$T, SrcT=int64](two:output:0) | |
| y = Mul[T=$T](x, scale:y:0) | |
| return y = y:z:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(expect, DebugString(test::function::XTimesTwo())); | |
| } | |
| TEST(TFunc, WXPlusB) { | |
| auto expect = R"P( | |
| WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { | |
| mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) | |
| y = Add[T=$T](mm:product:0, b) | |
| return y = y:z:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(expect, DebugString(test::function::WXPlusB())); | |
| } | |
| TEST(TFunc, Body_TypeList) { | |
| const Tensor kZero = test::AsScalar<int32>(0); | |
| auto fdef = FDH::Create( | |
| // Name | |
| "Test", | |
| // Args | |
| {"i:float"}, | |
| // Return values | |
| {"o:float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}}, | |
| {{"s"}, | |
| "Split", | |
| {"zero:output:0", "i"}, | |
| {{"num_split", 4}, {"T", DT_FLOAT}}}, | |
| {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}}, | |
| {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}}, | |
| {{"x"}, | |
| "_ListToArray", | |
| {"l:z", "r:z"}, | |
| {{"N", 2}, | |
| {"T", DT_FLOAT}, | |
| {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, | |
| {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}}, | |
| {{"o", "o:sum:0"}}); | |
| const char* e = R"P( | |
| Test(i:float) -> (o:float) { | |
| zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() | |
| s = Split[T=float, num_split=4](zero:output:0, i) | |
| l = Mul[T=float](s:output:0, s:output:1) | |
| r = Mul[T=float](s:output:2, s:output:3) | |
| x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z) | |
| o = AddN[N=2, T=float](x:output) | |
| return o = o:sum:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (i:float) -> (o:float) { | |
| zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() | |
| s = Split[T=float, num_split=4](zero, i) | |
| l = Mul[T=float](s, s:1) | |
| r = Mul[T=float](s:2, s:3) | |
| x = _ListToArray[N=2, T=float, Tin={float, float}](l, r) | |
| o = AddN[N=2, T=float](x, x:1) | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
// Test-only op "Cond": takes a type-list input and produces a type-list output,
// with three function-valued attrs (cond, then_branch, else_branch). Only its
// signature is used by these tests; the doc string describes intended meaning.
REGISTER_OP("Cond")
    .Input("input: Tin")
    .Output("output: out_types")
    .Attr("Tin: list(type)")
    .Attr("out_types: list(type)")
    .Attr("cond: func")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Doc(R"doc(
output = Cond(input) ? then_branch(input) : else_branch(input)
cond: A function takes 'input' and returns a scalar.
then_branch: A function takes 'input' and returns 'output'.
else_branch: A function takes 'input' and returns 'output'.
)doc");
| TEST(TFunc, Body_Array_List_Converter) { | |
| auto fdef = FDH::Define( | |
| // Name | |
| "MySelect", | |
| // Args | |
| {"x:float"}, | |
| // Return values | |
| {"z:float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| { | |
| {{"y"}, | |
| "Cond", | |
| {"x"}, | |
| {{"Tin", DataTypeSlice{DT_FLOAT}}, | |
| {"out_types", DataTypeSlice{DT_FLOAT}}, | |
| {"cond", FDH::FunctionRef("MyCond")}, | |
| {"then_branch", FDH::FunctionRef("MyThen")}, | |
| {"else_branch", FDH::FunctionRef("MyElse")}}}, | |
| {{"z"}, | |
| "Cond", | |
| {"y", "y"}, | |
| {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, | |
| {"out_types", DataTypeSlice{DT_FLOAT}}, | |
| {"cond", FDH::FunctionRef("MyCond2")}, | |
| {"then_branch", FDH::FunctionRef("MyThen2")}, | |
| {"else_branch", FDH::FunctionRef("MyElse2")}}}, | |
| }); | |
| const char* e = R"P( | |
| MySelect(x:float) -> (z:float) { | |
| y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) | |
| z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0) | |
| return z = z:output:0 | |
| } | |
| )P"; | |
| EXPECT_EQ(DebugString(fdef), e); | |
| InstantiationResult result; | |
| TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); | |
| const char* e2 = R"P( | |
| (x:float) -> (z:float) { | |
| y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) | |
| z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y) | |
| } | |
| )P"; | |
| EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); | |
| EXPECT_EQ(DebugString(result.nodes), e2); | |
| } | |
| static void HasError(const Status& s, const string& substr) { | |
| EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) | |
| << ">>" << s << "<<, expected substring >>" << substr << "<<"; | |
| } | |
| TEST(InstantiateErrors, Not_Sufficient_Attrs) { | |
| auto fdef = | |
| FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); | |
| InstantiationResult result; | |
| HasError( | |
| InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result), | |
| "Attr T is not found from "); | |
| } | |
| TEST(InstantiateErrors, Too_Many_Attrs) { | |
| auto fdef = | |
| FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}), | |
| GetOpSig, &result), | |
| "Attr U is not found in "); | |
| } | |
| TEST(InstantiateErrors, AttrValue_Value_Placeholder) { | |
| auto fdef = | |
| FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); | |
| InstantiationResult result; | |
| HasError( | |
| InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result), | |
| "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'"); | |
| } | |
| TEST(InstantiateErrors, Unbounded_Attr) { | |
| auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"}, | |
| { | |
| {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, | |
| }); | |
| InstantiationResult result; | |
| HasError( | |
| InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result), | |
| "Failed to bind all placeholders"); | |
| } | |
| TEST(InstantiateErrors, DupArgs) { | |
| auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Duplicated arg name"); | |
| } | |
| TEST(InstantiateErrors, Dup_Node_Names) { | |
| auto fdef = FDH::Define("test", {"x:float"}, {}, {}, | |
| { | |
| {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| }); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Duplicated ret name"); | |
| } | |
| TEST(InstantiateErrors, Node_Arg_Notfound) { | |
| auto fdef = FDH::Create("test", {"x:float"}, {}, {}, | |
| { | |
| {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}}, | |
| }, | |
| {}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "input z is not found"); | |
| } | |
| TEST(InstantiateErrors, Node_Arg_TypeMismatch) { | |
| auto fdef = FDH::Define("test", {"x:float"}, {}, {}, | |
| { | |
| {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, | |
| }); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "input x[0] expected type int32 != float, the type of x[0]"); | |
| } | |
| TEST(InstantiateErrors, Node_Arg_ControlMissing) { | |
| auto fdef = | |
| FDH::Define("test", {"x:float"}, {}, {}, | |
| { | |
| {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, | |
| }); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "input[2] == '^z', is not found."); | |
| } | |
| TEST(InstantiateErrors, FuncRet_Missing) { | |
| auto fdef = FDH::Create("test", {}, {"y: float"}, {}, | |
| { | |
| {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| }, | |
| {}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Return y missing"); | |
| } | |
| TEST(InstantiateErrors, FuncRet_NotFound) { | |
| auto fdef = FDH::Create("test", {}, {"y: float"}, {}, | |
| { | |
| {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| }, | |
| {{"y", "z"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Return y -> z is not found"); | |
| } | |
| TEST(InstantiateErrors, FuncRet_NameMismatch) { | |
| auto fdef = FDH::Create("test", {}, {"y: float"}, {}, | |
| { | |
| {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| }, | |
| {{"z", "x:y:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Return y missing"); | |
| } | |
| // TODO(josh11b): Make this an error. | |
| // TEST(InstantiateErrors, FuncRet_Extra) { | |
| // auto fdef = FDH::Create("test", {}, {"y: float"}, {}, | |
| // { | |
| // {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, | |
| // }, | |
| // {{"y", "x:y:0"}, {"z", "x:y:0"}}); | |
| // InstantiationResult result; | |
| // HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| // "ret is not found"); | |
| // } | |
| TEST(InstantiateErrors, FuncRet_TypeMismatch) { | |
| auto fdef = FDH::Define("test", {}, {"y: float"}, {}, | |
| { | |
| {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, | |
| }); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Invalid ret types y : float vs. double\n\tIn function output y"); | |
| } | |
| TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "MySelect", | |
| // Args | |
| {"x: float"}, | |
| // Return values | |
| {"y: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| { | |
| {{"y"}, | |
| "Cond", | |
| {"x", "x"}, | |
| {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, | |
| {"cond", FDH::FunctionRef("MyCond2")}, | |
| {"then_branch", FDH::FunctionRef("MyThen2")}, | |
| {"else_branch", FDH::FunctionRef("MyElse2")}}}, | |
| }, | |
| {{"y", "y:output"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "type attr not found: out_types"); | |
| } | |
| TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "MySelect", | |
| // Args | |
| {"x: float"}, | |
| // Return values | |
| {"y: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| { | |
| {{"y"}, | |
| "Cond", | |
| {"x", "x"}, | |
| {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, | |
| {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, | |
| {"cond", FDH::FunctionRef("MyCond2")}, | |
| {"then_branch", FDH::FunctionRef("MyThen2")}, | |
| {"else_branch", FDH::FunctionRef("MyElse2")}}}, | |
| }, | |
| {{"y", "y:output"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Invalid ret types"); | |
| } | |
| TEST(InstantiateErrors, TypeList_Missing_Arg) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "MySelect", | |
| // Args | |
| {"x: float"}, | |
| // Return values | |
| {"y: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| { | |
| {{"y"}, | |
| "Cond", | |
| {"x", "unknown"}, | |
| {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, | |
| {"out_types", DataTypeSlice{DT_FLOAT}}, | |
| {"cond", FDH::FunctionRef("MyCond2")}, | |
| {"then_branch", FDH::FunctionRef("MyThen2")}, | |
| {"else_branch", FDH::FunctionRef("MyElse2")}}}, | |
| }, | |
| {{"y", "y:output"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "input unknown is not found"); | |
| } | |
| TEST(InstantiateErrors, TooManyInputs) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "TooManyInputs", | |
| // Inputs | |
| {"x: float", "y: float"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = AddN<N=2>(x, y, x) | |
| {{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Expected input[2] == 'x' to be a control input."); | |
| } | |
| TEST(InstantiateErrors, TooFewInputs) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "TooFewInputs", | |
| // Inputs | |
| {"x: float", "y: float"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = AddN<N=3>(x, y) | |
| {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Attempt to access beyond input size: 2 >= 2"); | |
| } | |
| TEST(InstantiateErrors, TooManyInputsFromArray1) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "TooManyInputsFromArray", | |
| // Inputs | |
| {"x: float", "y: float"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = _ListToArray(x,y) | |
| {{"a"}, | |
| "_ListToArray", | |
| {"x", "y"}, | |
| {{"N", 2}, | |
| {"T", DT_FLOAT}, | |
| {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, | |
| // b = AddN<N=2>(a, y) | |
| {{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Expected input[1] == 'y' to be a control input."); | |
| } | |
| TEST(InstantiateErrors, TooManyInputsFromArray2) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "TooManyInputsFromArray", | |
| // Inputs | |
| {"x: float", "y: float"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = _ListToArray(x,y) | |
| {{"a"}, | |
| "_ListToArray", | |
| {"x", "y"}, | |
| {{"N", 2}, | |
| {"T", DT_FLOAT}, | |
| {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, | |
| // b = AddN<N=2>(x, a) | |
| {{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "Input a:output too long for inputs"); | |
| } | |
| TEST(InstantiateErrors, TypeMismatch) { | |
| auto fdef = FDH::Create( | |
| // Name | |
| "TypeMismatch", | |
| // Inputs | |
| {"x: float", "y: int32"}, | |
| // Outputs | |
| {"z: float"}, | |
| // Attrs | |
| {}, | |
| // Nodes | |
| {// a = AddN<N=2>(x, y) | |
| {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, | |
| // Returns | |
| {{"z", "a:sum:0"}}); | |
| InstantiationResult result; | |
| HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), | |
| "input inputs[1] expected type float != int32, the type of y[0]"); | |
| } | |
| TEST(FunctionCallFrame, Void_Void) { | |
| FunctionCallFrame frame({}, {}); | |
| TF_EXPECT_OK(frame.SetArgs({})); | |
| auto a = test::AsTensor<float>({100}); | |
| HasError(frame.SetArgs({a}), "Invalid argument"); | |
| Tensor v; | |
| HasError(frame.GetArg(0, &v), "Invalid argument"); | |
| HasError(frame.SetRetval(0, v), "Invalid argument"); | |
| std::vector<Tensor> rets; | |
| TF_EXPECT_OK(frame.GetRetvals(&rets)); | |
| EXPECT_EQ(rets.size(), 0); | |
| } | |
| TEST(FunctionCallFrame, Float_Float_Float) { | |
| FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); | |
| HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments"); | |
| auto a = test::AsTensor<float>({100}); | |
| auto b = test::AsTensor<float>({200}); | |
| auto c = test::AsTensor<int64>({300}); | |
| HasError(frame.SetArgs({a, c}), | |
| "Invalid argument: Expects arg[1] to be float"); | |
| TF_EXPECT_OK(frame.SetArgs({a, b})); | |
| Tensor v; | |
| HasError(frame.GetArg(-1, &v), "Invalid argument"); | |
| HasError(frame.GetArg(2, &v), "Invalid argument"); | |
| TF_EXPECT_OK(frame.GetArg(0, &v)); | |
| test::ExpectTensorEqual<float>(a, v); | |
| TF_EXPECT_OK(frame.GetArg(1, &v)); | |
| test::ExpectTensorEqual<float>(b, v); | |
| v = test::AsTensor<float>({-100}); | |
| HasError(frame.SetRetval(-1, v), "Invalid argument"); | |
| HasError(frame.SetRetval(1, v), "Invalid argument"); | |
| HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})), | |
| "Invalid argument: Expects ret[0] to be float"); | |
| std::vector<Tensor> rets; | |
| HasError(frame.GetRetvals(&rets), "does not have value"); | |
| TF_EXPECT_OK(frame.SetRetval(0, v)); | |
| HasError(frame.SetRetval(0, v), "has already been set"); | |
| TF_EXPECT_OK(frame.GetRetvals(&rets)); | |
| EXPECT_EQ(rets.size(), 1); | |
| test::ExpectTensorEqual<float>(rets[0], v); | |
| } | |
| TEST(Canonicalize, Basic) { | |
| EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, | |
| {"transpose_a", false}, | |
| {"transpose_b", false}})), | |
| "MatMul[T=float,transpose_a=false,transpose_b=false]"); | |
| EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, | |
| {"transpose_b", false}, | |
| {"transpose_a", false}})), | |
| "MatMul[T=float,transpose_a=false,transpose_b=false]"); | |
| EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE}, | |
| {"transpose_b", true}, | |
| {"transpose_a", false}})), | |
| "MatMul[T=double,transpose_a=false,transpose_b=true]"); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, Find) { | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| EXPECT_EQ(lib_def.Find("XTimes16"), nullptr); | |
| auto expect = R"P( | |
| XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { | |
| two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() | |
| scale = Cast[DstT=$T, SrcT=int64](two:output:0) | |
| y = Mul[T=$T](x, scale:y:0) | |
| return y = y:z:0 | |
| } | |
| )P"; | |
| auto found = lib_def.Find("XTimesTwo"); | |
| ASSERT_NE(found, nullptr); | |
| EXPECT_EQ(expect, DebugString(*found)); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, LookUp) { | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| const OpDef* op_def; | |
| EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok()); | |
| TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def)); | |
| ASSERT_NE(op_def, nullptr); | |
| EXPECT_EQ(op_def->DebugString(), | |
| test::function::XTimesTwo().signature().DebugString()); | |
| const OpRegistrationData* op_reg_data; | |
| TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data)); | |
| ASSERT_NE(op_reg_data, nullptr); | |
| // Shape inference function is initialized to UnknownShape. | |
| ASSERT_NE(op_reg_data->shape_inference_fn, nullptr); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { | |
| // Add one function to the proto lib before constructing 'lib_def'. | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| // Add a new function def to the library. | |
| TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); | |
| // Test lookup of first function. | |
| const OpDef* first; | |
| TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first)); | |
| ASSERT_NE(first, nullptr); | |
| EXPECT_EQ(first->DebugString(), | |
| test::function::XTimesTwo().signature().DebugString()); | |
| // Test lookup of second function. | |
| const OpDef* second; | |
| TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second)); | |
| ASSERT_NE(second, nullptr); | |
| EXPECT_EQ(second->DebugString(), | |
| test::function::WXPlusB().signature().DebugString()); | |
| // Can't add function with same name as existing op | |
| FunctionDef fdef = test::function::XTimesTwo(); | |
| fdef.mutable_signature()->set_name("Add"); | |
| Status s = lib_def.AddFunctionDef(fdef); | |
| EXPECT_FALSE(s.ok()); | |
| EXPECT_EQ(s.error_message(), | |
| "Cannot add function 'Add' because an op with the same name " | |
| "already exists."); | |
| // Already-added functions don't produce error | |
| TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); | |
| TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddGradientDef) { | |
| // AddGradientDef() doesn't check that functions referenced exist (yet?) | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); | |
| // Test adding a gradient (XTimesFour isn't a valid grad function for | |
| // XTimesTwo but that's ok for now) | |
| GradientDef grad; | |
| grad.set_function_name(test::function::XTimesTwo().signature().name()); | |
| grad.set_gradient_func(test::function::XTimesFour().signature().name()); | |
| TF_EXPECT_OK(lib_def.AddGradientDef(grad)); | |
| // Already-added gradients don't produce error | |
| TF_EXPECT_OK(lib_def.AddGradientDef(grad)); | |
| // Test that adding a duplicate gradient fails | |
| grad.set_gradient_func(test::function::XTimes16().signature().name()); | |
| Status s = lib_def.AddGradientDef(grad); | |
| EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); | |
| EXPECT_EQ(s.error_message(), | |
| "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " | |
| "it already has gradient function 'XTimesFour'"); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddLibrary) { | |
| // Create lib def with single function | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| // Add gradient | |
| GradientDef grad; | |
| grad.set_function_name(test::function::XTimesTwo().signature().name()); | |
| grad.set_gradient_func(test::function::XTimesFour().signature().name()); | |
| TF_EXPECT_OK(lib_def.AddGradientDef(grad)); | |
| // Error if you try to add conflicting function | |
| proto.Clear(); | |
| FunctionDef fdef = test::function::XTimesFour(); | |
| fdef.mutable_signature()->set_name( | |
| test::function::XTimesTwo().signature().name()); | |
| *proto.add_function() = fdef; | |
| FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); | |
| Status s = lib_def.AddLibrary(lib_def2); | |
| EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); | |
| EXPECT_EQ(s.error_message(), | |
| "Cannot add function 'XTimesTwo' because a different function with " | |
| "the same name already exists."); | |
| // Error if you try to add conflicting gradient | |
| proto.Clear(); | |
| grad.set_gradient_func(test::function::XTimes16().signature().name()); | |
| *proto.add_gradient() = grad; | |
| FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); | |
| s = lib_def.AddLibrary(lib_def3); | |
| EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); | |
| EXPECT_EQ(s.error_message(), | |
| "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " | |
| "it already has gradient function 'XTimesFour'"); | |
| // No conflicting functions or gradients OK | |
| proto.Clear(); | |
| *proto.add_function() = test::function::XTimesFour(); | |
| grad.set_function_name(test::function::XTimes16().signature().name()); | |
| *proto.add_gradient() = grad; | |
| FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto); | |
| TF_EXPECT_OK(lib_def.AddLibrary(lib_def4)); | |
| // OK to add the same functions and gradients twice | |
| TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); | |
| } | |
| GradientDef MakeGradDef(const string& f, const string& g) { | |
| GradientDef grad; | |
| grad.set_function_name(f); | |
| grad.set_gradient_func(g); | |
| return grad; | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { | |
| // Create lib def containing two functions with equal names | |
| FunctionDefLibrary proto; | |
| const string x2_name = test::function::XTimesTwo().signature().name(); | |
| const string x4_name = test::function::XTimesFour().signature().name(); | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionDef fdef = test::function::XTimesFour(); | |
| fdef.mutable_signature()->set_name(x2_name); | |
| *proto.add_function() = fdef; | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); | |
| // Try adding the two functions to lib_def | |
| Status s = lib_def.AddLibrary(proto); | |
| EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); | |
| EXPECT_EQ( | |
| "Cannot add function 'XTimesTwo' because a different function with " | |
| "the same name already exists.", | |
| s.error_message()); | |
| // Verify that none of the functions are added | |
| EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); | |
| // Fix the name in proto but add two gradient names for it | |
| proto.mutable_function(1)->mutable_signature()->set_name(x4_name); | |
| *proto.add_gradient() = MakeGradDef(x2_name, x4_name); | |
| *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName"); | |
| // Try adding the library and check that nothing was added | |
| s = lib_def.AddLibrary(proto); | |
| EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); | |
| EXPECT_EQ(s.error_message(), | |
| "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' " | |
| "because it already has gradient function 'XTimesFour'"); | |
| EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); | |
| EXPECT_EQ(0, lib_def.ToProto().function_size()); | |
| EXPECT_EQ(0, lib_def.ToProto().gradient_size()); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { | |
| const string x2_name = test::function::XTimesTwo().signature().name(); | |
| const string x4_name = test::function::XTimesFour().signature().name(); | |
| const string wx_name = test::function::WXPlusB().signature().name(); | |
| // Create FunctionLibraryDefinition with | |
| // (func = XTimesTwo, grad = XTimesFour) | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| *proto.add_gradient() = MakeGradDef(x2_name, x4_name); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| EXPECT_EQ(1, lib_def.ToProto().function_size()); | |
| EXPECT_EQ(1, lib_def.ToProto().gradient_size()); | |
| // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) | |
| // and function (name = XTimesTwo, body = XTimeFour) | |
| FunctionDefLibrary proto2; | |
| *proto2.add_function() = test::function::WXPlusB(); | |
| *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); | |
| *proto2.add_function() = test::function::XTimesFour(); | |
| proto2.mutable_function(1)->mutable_signature()->set_name(x2_name); | |
| FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); | |
| // Verify that adding lib_def2 will fail because of function conflict | |
| // and WXPlusB is not added. | |
| Status s = lib_def.AddLibrary(lib_def2); | |
| EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); | |
| EXPECT_EQ( | |
| "Cannot add function 'XTimesTwo' because a different function " | |
| "with the same name already exists.", | |
| s.error_message()); | |
| EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); | |
| EXPECT_EQ(1, lib_def.ToProto().function_size()); | |
| EXPECT_EQ(1, lib_def.ToProto().gradient_size()); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { | |
| const string x2_name = test::function::XTimesTwo().signature().name(); | |
| const string x4_name = test::function::XTimesFour().signature().name(); | |
| const string wx_name = test::function::WXPlusB().signature().name(); | |
| // Create FunctionLibraryDefinition with | |
| // (func = XTimesTwo, grad = XTimesFour) | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| *proto.add_gradient() = MakeGradDef(x2_name, x4_name); | |
| FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); | |
| EXPECT_EQ(1, lib_def.ToProto().function_size()); | |
| EXPECT_EQ(1, lib_def.ToProto().gradient_size()); | |
| // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) | |
| // and (func = XTimesTwo, grad = WXPlusB) | |
| FunctionDefLibrary proto2; | |
| *proto2.add_function() = test::function::WXPlusB(); | |
| *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); | |
| *proto2.add_function() = test::function::XTimesTwo(); | |
| *proto2.add_gradient() = MakeGradDef(x2_name, wx_name); | |
| FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); | |
| // Verify that adding lib_def2 will fail because of gradient conflict | |
| // and WXPlusB is not added. | |
| Status s = lib_def.AddLibrary(lib_def2); | |
| EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); | |
| EXPECT_EQ( | |
| "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'" | |
| " because it already has gradient function 'XTimesFour'", | |
| s.error_message()); | |
| EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); | |
| EXPECT_EQ(1, lib_def.ToProto().function_size()); | |
| EXPECT_EQ(1, lib_def.ToProto().gradient_size()); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, ToProto) { | |
| FunctionDefLibrary proto1; | |
| *proto1.add_function() = test::function::XTimesTwo(); | |
| *proto1.add_function() = test::function::WXPlusB(); | |
| FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1); | |
| // Call 'ToProto' and make sure both protos have the same function lib size. | |
| FunctionDefLibrary proto2 = lib_def1.ToProto(); | |
| EXPECT_EQ(proto1.function_size(), proto2.function_size()); | |
| // Initialize 'lib_def2' with proto returned by 'ToProto' call. | |
| FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); | |
| // Test that the first function exists in both libraries. | |
| const OpDef *f1, *f2, *f3, *f4; | |
| TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1)); | |
| TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2)); | |
| EXPECT_EQ(f1->DebugString(), f2->DebugString()); | |
| // Test that the second function exists in both libraries. | |
| TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3)); | |
| TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4)); | |
| EXPECT_EQ(f3->DebugString(), f4->DebugString()); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) { | |
| FunctionDefLibrary proto; | |
| *proto.add_function() = test::function::XTimesTwo(); | |
| FunctionLibraryDefinition lib(OpRegistry::Global(), proto); | |
| NodeDef ndef; | |
| bool annotation; | |
| // Not a function. | |
| ndef.set_op("Matmul"); | |
| EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); | |
| // A function. No attr defined. | |
| ndef.set_op("XTimesTwo"); | |
| EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); | |
| // ndef defines the attr. But we don't care. | |
| AddNodeAttr("annotation", true, &ndef); | |
| EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); | |
| } | |
| template <typename T> | |
| void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) { | |
| AttrValue attr_value; | |
| SetAttrValue(value, &attr_value); | |
| fdef->mutable_attr()->insert({attr, attr_value}); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) { | |
| FunctionDefLibrary proto; | |
| auto fdef = proto.add_function(); | |
| *fdef = test::function::XTimesTwo(); | |
| SetAttrValue(fdef, "annotation", true); | |
| SetAttrValue(fdef, "options", "some string data"); | |
| FunctionLibraryDefinition lib(OpRegistry::Global(), proto); | |
| NodeDef ndef; | |
| bool annotation; | |
| // A function. No attr defined in ndef. | |
| ndef.set_op("XTimesTwo"); | |
| TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); | |
| EXPECT_EQ(annotation, true); | |
| string str; | |
| TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str)); | |
| EXPECT_EQ(str, "some string data"); | |
| } | |
| TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) { | |
| FunctionDefLibrary proto; | |
| auto fdef = proto.add_function(); | |
| *fdef = test::function::XTimesTwo(); | |
| SetAttrValue(fdef, "annotation", true); | |
| *fdef = test::function::WXPlusB(); | |
| SetAttrValue(fdef, "annotation", false); | |
| auto func_grad = proto.add_gradient(); | |
| func_grad->set_function_name("XTimesTwo"); | |
| func_grad->set_gradient_func("WXPlusB"); | |
| FunctionLibraryDefinition lib(OpRegistry::Global(), proto); | |
| NodeDef ndef; | |
| ndef.set_op(FunctionLibraryDefinition::kGradientOp); | |
| bool annotation; | |
| EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); | |
| NameAttrList nal; | |
| nal.set_name("XTimesTwo"); | |
| AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); | |
| TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); | |
| EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB. | |
| nal.set_name("WXPlusB"); | |
| ndef.clear_attr(); | |
| AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); | |
| TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); | |
| EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient. | |
| } | |
| // TODO(skyewm): this could be more thorough | |
| TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { | |
| // Equal functions | |
| const FunctionDef fdef1 = test::function::XTimesTwo(); | |
| FunctionDef fdef2 = test::function::XTimesTwo(); | |
| uint64 hash1 = FunctionDefHash(fdef1); | |
| EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_EQ(hash1, FunctionDefHash(fdef2)); | |
| // Different functions | |
| fdef2 = test::function::XTimesFour(); | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Different signatures | |
| fdef2 = test::function::XTimesTwo(); | |
| fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo"); | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Descriptions must be equal | |
| fdef2 = test::function::XTimesTwo(); | |
| fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo"); | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Different NodeDefs | |
| fdef2 = test::function::XTimesTwo(); | |
| NodeDef* ndef = fdef2.add_node_def(); | |
| *ndef = fdef2.node_def(0); | |
| ndef->set_name("new_name"); | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Different return values | |
| fdef2 = test::function::XTimesTwo(); | |
| (*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0" | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Different attributes | |
| fdef2 = test::function::XTimesTwo(); | |
| SetAttrValue(&fdef2, "ExtraAttr", true); | |
| EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); | |
| EXPECT_NE(hash1, FunctionDefHash(fdef2)); | |
| // Multiple equivalent attributes; the two functions should be equal. | |
| fdef2 = test::function::XTimesTwo(); | |
| FunctionDef fdef3 = test::function::XTimesTwo(); | |
| SetAttrValue(&fdef2, "Foo", true); | |
| SetAttrValue(&fdef3, "Foo", true); | |
| SetAttrValue(&fdef2, "Bar", 123); | |
| SetAttrValue(&fdef3, "Bar", 123); | |
| SetAttrValue(&fdef2, "Baz", "abc"); | |
| SetAttrValue(&fdef3, "Baz", "abc"); | |
| EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3)); | |
| EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3)); | |
| } | |
| } // end namespace | |
| } // end namespace tensorflow | |