/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tf_types.h"

#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project

namespace {

// Returns the shape of the given value if it's ranked; returns llvm::None
// otherwise.
llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
  // Values handled here are always tensors, so the cast is expected to hold;
  // an unranked tensor has no shape to report.
  auto shaped_type = value.getType().cast<mlir::ShapedType>();
  if (shaped_type.hasRank()) return shaped_type.getShape();
  return llvm::None;
}

// Merges cast compatible shapes and returns a more refined shape. The two
// shapes are cast compatible if they have the same rank and at each dimension,
// either both have same size or one of them is dynamic. Returns false if the
// given shapes are not cast compatible. The refined shape is same or more
// precise than the two input shapes.
bool GetCastCompatibleShape(llvm::ArrayRef a_shape, llvm::ArrayRef b_shape, llvm::SmallVectorImpl* refined_shape) { if (a_shape.size() != b_shape.size()) return false; int64_t rank = a_shape.size(); refined_shape->reserve(rank); for (auto dims : llvm::zip(a_shape, b_shape)) { int64_t dim1 = std::get<0>(dims); int64_t dim2 = std::get<1>(dims); if (mlir::ShapedType::isDynamic(dim1)) { refined_shape->push_back(dim2); continue; } if (mlir::ShapedType::isDynamic(dim2)) { refined_shape->push_back(dim1); continue; } if (dim1 == dim2) { refined_shape->push_back(dim1); continue; } return false; } return true; } } // namespace namespace mlir { namespace TF { //===----------------------------------------------------------------------===// // Utility iterators //===----------------------------------------------------------------------===// OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it) : llvm::mapped_iterator > (*)(Value)>( it, &GetShape) { } ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it) : llvm::mapped_iterator > (*)(Value)>( it, &GetShape) { } //===----------------------------------------------------------------------===// // TF types helper functions //===----------------------------------------------------------------------===// bool TensorFlowType::classof(Type type) { return type.getDialect().getNamespace() == "tf"; } bool TensorFlowRefType::classof(Type type) { return type.isa< #define HANDLE_TF_TYPE(tftype, enumerant, name) #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type, #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type // NOLINTNEXTLINE #include "tf_types.def" >(); } bool TensorFlowTypeWithSubtype::classof(Type type) { return type.isa(); } TensorFlowType TensorFlowRefType::get(Type type) { MLIRContext* ctx = type.getContext(); type = getElementTypeOrSelf(type); if (type.isF16()) { return HalfRefType::get(ctx); } else if (type.isF32()) { return FloatRefType::get(ctx); } else if 
(type.isF64()) { return DoubleRefType::get(ctx); } else if (type.isBF16()) { return Bfloat16RefType::get(ctx); } else if (auto complex_type = type.dyn_cast()) { Type etype = complex_type.getElementType(); if (etype.isF32()) { return Complex64RefType::get(ctx); } else if (etype.isF64()) { return Complex128RefType::get(ctx); } llvm_unreachable("unexpected complex type"); } else if (auto itype = type.dyn_cast()) { switch (itype.getWidth()) { case 1: return BoolRefType::get(ctx); case 8: return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) : Int8RefType::get(ctx); case 16: return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) : Int16RefType::get(ctx); case 32: return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) : Int32RefType::get(ctx); case 64: return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx)) : Int64RefType::get(ctx); default: llvm_unreachable("unexpected integer type"); } } #define HANDLE_TF_TYPE(tftype, enumerant, name) \ if (auto derived_ty = type.dyn_cast()) \ return tftype##RefType::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tf_types.def" llvm_unreachable("unexpected type kind"); } Type TensorFlowRefType::RemoveRef() { MLIRContext* ctx = getContext(); if (isa()) return mlir::FloatType::getF16(ctx); if (isa()) return mlir::FloatType::getF32(ctx); if (isa()) return mlir::FloatType::getF64(ctx); if (isa()) return mlir::FloatType::getBF16(ctx); if (isa()) return mlir::IntegerType::get(ctx, 1); if (isa()) return mlir::IntegerType::get(ctx, 8); if (isa()) return mlir::IntegerType::get(ctx, 16); if (isa()) return mlir::IntegerType::get(ctx, 32); if (isa()) return mlir::IntegerType::get(ctx, 64); if (isa()) return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned); if (isa()) return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned); if (isa()) return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned); if (isa()) return mlir::IntegerType::get(ctx, 64, 
IntegerType::Unsigned); if (isa()) return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); if (isa()) return mlir::ComplexType::get(mlir::FloatType::getF64(ctx)); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ if (isa()) return tftype##Type::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tf_types.def" llvm_unreachable("unexpected tensorflow ref type kind"); } Type TensorFlowTypeWithSubtype::RemoveSubtypes() { MLIRContext* ctx = getContext(); if (isa()) return VariantType::get(ctx); if (isa()) return ResourceType::get(ctx); llvm_unreachable("unexpected tensorflow type with subtypes kind"); } ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { if (auto variant_type = dyn_cast()) return variant_type.getSubtypes(); if (auto resource_type = dyn_cast()) return resource_type.getSubtypes(); llvm_unreachable("unexpected tensorflow type with subtypes kind"); } // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have // similar structure that could be extracted into helper method. bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) { if (lhs.size() != rhs.size()) return false; for (auto types : llvm::zip(lhs, rhs)) { // Drop ref types because they don't affect broadcast compatibility. E.g., // `tensor` and `tensor` should be considered broadcast // compatible. auto lhs_type = DropRefType(std::get<0>(types)); auto rhs_type = DropRefType(std::get<1>(types)); // This should be true for all TF ops: auto lhs_tt = lhs_type.dyn_cast(); auto rhs_tt = rhs_type.dyn_cast(); if (!lhs_tt || !rhs_tt) { if (lhs_type != rhs_type) return false; continue; } // Verify matching element types. These should be identical, except for // variant type where unknown subtype is considered compatible with all // subtypes. auto lhs_et = lhs_tt.getElementType(); auto rhs_et = rhs_tt.getElementType(); if (lhs_et != rhs_et) { // If either does not have subtypes, then the element types don't match. 
auto lhs_wst = lhs_et.dyn_cast(); auto rhs_wst = rhs_et.dyn_cast(); if (!lhs_wst || !rhs_wst) return false; // Consider the subtype of variant types. auto lhs_wst_st = lhs_wst.GetSubtypes(); auto rhs_wst_st = rhs_wst.GetSubtypes(); if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) { for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) { if (!BroadcastCompatible(std::get<0>(subtypes), std::get<1>(subtypes))) return false; } } } auto lhs_rt = lhs_type.dyn_cast(); auto rhs_rt = rhs_type.dyn_cast(); if (!lhs_rt || !rhs_rt) return true; SmallVector shape; return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(), rhs_rt.getShape(), shape); } return true; } // Given two types `a` and `b`, returns a refined type which is cast compatible // with both `a` and `b` and is equal to or more precise than both of them. It // returns empty Type if the input types are not cast compatible. // // The two types are considered cast compatible if they have dynamically equal // shapes and element type. For element types that do not have subtypes, they // must be equal. However for TensorFlow types such as Resource and Variant, // that also have subtypes, we recursively check for subtype compatibility for // Resource types and assume all variant types are cast compatible. If either // one of `a` or `b` have empty subtypes, they are considered cast compatible. // // The returned type is same or more precise than the input types. For example, // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and // tensor respectively, the returned type is tensor<2x4x?xf32>. // // Provides option to ignore ref types on 'a'. This is useful for TF ops that // might allow operands to either be same as result type or be a ref type // corresponding to it. mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, bool may_ignore_ref_type_a) { // Fast path if everything is equal. 
if (a == b) return b; auto a_tt = a.dyn_cast(); auto b_tt = b.dyn_cast(); // If only one of a or b is a tensor type, they are incompatible. if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; // For non-tensor types, we do not need to worry about shape and can return // early. if (!a_tt && !b_tt) { // Remove ref types. if (may_ignore_ref_type_a) { if (auto ref_type = a.dyn_cast()) { a = ref_type.RemoveRef(); if (a == b) return a; } } if (a.getTypeID() != b.getTypeID()) return nullptr; // If either is not a type that contain subtypes then the types are not cast // compatible. auto a_wst = a.dyn_cast(); auto b_wst = b.dyn_cast(); if (!a_wst || !b_wst) return nullptr; // For Variant types we are more permissive right now and accept all pairs // of Variant types. If we are more constrainted and check compatibility of // subtypes, we might reject valid graphs. // TODO(prakalps): Variant doesn't have a subtype, we assign it // one, so we should only assign it one when we know the subtype. Then we // can be more constrained and check subtypes for cast compatibility as // well. if (a.isa()) return a; // For Resource types, we recursively check the subtypes for cast // compatibility, if possible. Otherwise treat them as compatible. auto a_wst_st = a_wst.GetSubtypes(); auto b_wst_st = b_wst.GetSubtypes(); if (a_wst_st.empty() || b_wst_st.empty()) return a; if (a_wst_st.size() != b_wst_st.size()) return nullptr; llvm::SmallVector refined_subtypes; for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) { mlir::Type refined_st = GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), /*may_ignore_ref_type_a=*/false); if (!refined_st) return nullptr; refined_subtypes.push_back(refined_st.cast()); } return mlir::TF::ResourceType::get(refined_subtypes, a.getContext()); } // For tensor types, check compatibility of both element type and shape. 
mlir::Type refined_element_ty = GetCastCompatibleType( a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a); if (!refined_element_ty) return nullptr; if (!a_tt.hasRank() && !b_tt.hasRank()) { return mlir::UnrankedTensorType::get(refined_element_ty); } if (!a_tt.hasRank()) { return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty); } if (!b_tt.hasRank()) { return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty); } llvm::SmallVector refined_shape; if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape)) return nullptr; return mlir::RankedTensorType::get(refined_shape, refined_element_ty); } bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs) { return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr; } bool AreCastCompatible(TypeRange types) { Type common = types.front(); for (auto type : types.drop_front()) { Type refined_type = GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false); if (!refined_type) return false; common = refined_type; } return true; } bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) { if (lhs.size() != rhs.size()) return false; for (auto pair : llvm::zip(lhs, rhs)) { auto lhs_i = std::get<0>(pair); auto rhs_i = std::get<1>(pair); if (!AreCastCompatible({lhs_i, rhs_i})) return false; } return true; } // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default // type for a composed type (such as a ref type or a type with subtypes). 
// Returns the default (composition-free) type of the given composed type:
// the ref-free type for a ref type, the subtype-free type for a type with
// subtypes. NOTE(review): these overloads were missing from this file even
// though DropTypeHelper requires them; remove if already declared elsewhere.
static Type GetDefaultTypeOf(TensorFlowRefType type) {
  return type.RemoveRef();
}
static Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
  return type.RemoveSubtypes();
}

// Strips the ComposedType wrapper from the element type of `ty` (if present),
// preserving the tensor shape/rank of `ty`. Returns `ty` unchanged when its
// element type is not a ComposedType.
template <typename ComposedType>
Type DropTypeHelper(Type ty) {
  Type element_ty = getElementTypeOrSelf(ty);
  auto composed_type = element_ty.dyn_cast<ComposedType>();
  if (!composed_type) return ty;

  Type default_ty = GetDefaultTypeOf(composed_type);
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
    // Keep the original shape, only swap the element type.
    return RankedTensorType::get(ranked_ty.getShape(), default_ty);
  } else if (ty.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(default_ty);
  } else {
    // `ty` was a bare element type.
    return default_ty;
  }
}

Type DropSubTypes(Type ty) {
  return DropTypeHelper<TensorFlowTypeWithSubtype>(ty);
}

Type DropRefType(Type ty) { return DropTypeHelper<TensorFlowRefType>(ty); }

Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }

}  // namespace TF
}  // namespace mlir