| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "tf_types.h" |
| |
|
| | #include "llvm/Support/ErrorHandling.h" |
| | #include "mlir/Dialect/Traits.h" |
| | #include "mlir/IR/BuiltinTypes.h" |
| | #include "mlir/IR/Dialect.h" |
| | #include "mlir/IR/TypeUtilities.h" |
| |
|
| | namespace { |
| | |
| | |
| | llvm::Optional<llvm::ArrayRef<int64_t> > GetShape(mlir::Value value) |
| | { |
| | auto shaped_type = value.getType().cast<mlir::ShapedType>(); |
| | if (shaped_type.hasRank()) return shaped_type.getShape(); |
| | return llvm::None; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape, |
| | llvm::ArrayRef<int64_t> b_shape, |
| | llvm::SmallVectorImpl<int64_t>* refined_shape) |
| | { |
| | if (a_shape.size() != b_shape.size()) return false; |
| | int64_t rank = a_shape.size(); |
| | refined_shape->reserve(rank); |
| | for (auto dims : llvm::zip(a_shape, b_shape)) |
| | { |
| | int64_t dim1 = std::get<0>(dims); |
| | int64_t dim2 = std::get<1>(dims); |
| |
|
| | if (mlir::ShapedType::isDynamic(dim1)) |
| | { |
| | refined_shape->push_back(dim2); |
| | continue; |
| | } |
| | if (mlir::ShapedType::isDynamic(dim2)) |
| | { |
| | refined_shape->push_back(dim1); |
| | continue; |
| | } |
| | if (dim1 == dim2) |
| | { |
| | refined_shape->push_back(dim1); |
| | continue; |
| | } |
| | return false; |
| | } |
| | return true; |
| | } |
| |
|
| | } |
| |
|
| | namespace mlir { |
| | namespace TF { |
| | |
| | |
| | |
| |
|
// Adapts an operand iterator so dereferencing yields the operand's shape
// (or llvm::None for unranked operands) via GetShape.
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}
| |
|
// Adapts a result iterator so dereferencing yields the result's shape
// (or llvm::None for unranked results) via GetShape.
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}
| |
|
| | |
| | |
| | |
| |
|
| | bool TensorFlowType::classof(Type type) |
| | { |
| | return type.getDialect().getNamespace() == "tf"; |
| | } |
// Returns true iff `type` is one of the TF ref types enumerated in
// tf_types.def. The HANDLE_* macros expand the .def entries into a
// comma-separated type list inside a single isa<> check; only
// HANDLE_TF_REF_TYPE entries contribute (the plain-type macro expands to
// nothing), and the LAST entry omits the trailing comma.
bool TensorFlowRefType::classof(Type type)
{
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type

#include "tf_types.def"
      >();
}
| | bool TensorFlowTypeWithSubtype::classof(Type type) |
| | { |
| | return type.isa<ResourceType, VariantType>(); |
| | } |
| |
|
// Maps a non-ref type to its corresponding TensorFlow "ref" type
// (e.g. f32 -> FloatRefType). Shaped types are mapped via their element
// type. Aborts (llvm_unreachable) for types with no ref counterpart.
TensorFlowType TensorFlowRefType::get(Type type)
{
  MLIRContext* ctx = type.getContext();
  // Operate on the element type so tensors map through their elements.
  type = getElementTypeOrSelf(type);
  if (type.isF16())
  {
    return HalfRefType::get(ctx);
  }
  else if (type.isF32())
  {
    return FloatRefType::get(ctx);
  }
  else if (type.isF64())
  {
    return DoubleRefType::get(ctx);
  }
  else if (type.isBF16())
  {
    return Bfloat16RefType::get(ctx);
  }
  else if (auto complex_type = type.dyn_cast<ComplexType>())
  {
    // Complex types are distinguished by their float element width.
    Type etype = complex_type.getElementType();
    if (etype.isF32())
    {
      return Complex64RefType::get(ctx);
    }
    else if (etype.isF64())
    {
      return Complex128RefType::get(ctx);
    }
    llvm_unreachable("unexpected complex type");
  }
  else if (auto itype = type.dyn_cast<IntegerType>())
  {
    // Integer types are distinguished by width and signedness.
    switch (itype.getWidth())
    {
      case 1:
        return BoolRefType::get(ctx);
      case 8:
        return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                  : Int8RefType::get(ctx);
      case 16:
        return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                  : Int16RefType::get(ctx);
      case 32:
        return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                  : Int32RefType::get(ctx);
      case 64:
        return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                  : Int64RefType::get(ctx);
      default:
        llvm_unreachable("unexpected integer type");
    }
  }
// For every non-ref TF type in tf_types.def, emit a check mapping it to
// its ref counterpart; ref-type entries themselves expand to nothing.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
    return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)

#include "tf_types.def"
  llvm_unreachable("unexpected type kind");
}
| |
|
// Inverse of TensorFlowRefType::get: maps this ref type back to the
// underlying non-ref type (e.g. FloatRefType -> f32). Aborts on an
// unrecognized ref kind.
Type TensorFlowRefType::RemoveRef()
{
  MLIRContext* ctx = getContext();
  // Builtin float/int/complex counterparts are constructed directly.
  if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
  if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
  if (isa<Complex128RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
// TF-dialect types from tf_types.def map ref -> plain via the same name;
// ref-type entries themselves expand to nothing.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)

#include "tf_types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
| |
|
| | Type TensorFlowTypeWithSubtype::RemoveSubtypes() |
| | { |
| | MLIRContext* ctx = getContext(); |
| | if (isa<VariantType>()) return VariantType::get(ctx); |
| | if (isa<ResourceType>()) return ResourceType::get(ctx); |
| | llvm_unreachable("unexpected tensorflow type with subtypes kind"); |
| | } |
| |
|
| | ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() |
| | { |
| | if (auto variant_type = dyn_cast<VariantType>()) |
| | return variant_type.getSubtypes(); |
| | if (auto resource_type = dyn_cast<ResourceType>()) |
| | return resource_type.getSubtypes(); |
| | llvm_unreachable("unexpected tensorflow type with subtypes kind"); |
| | } |
| |
|
| | |
| | |
| | bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) |
| | { |
| | if (lhs.size() != rhs.size()) return false; |
| | for (auto types : llvm::zip(lhs, rhs)) |
| | { |
| | |
| | |
| | |
| | auto lhs_type = DropRefType(std::get<0>(types)); |
| | auto rhs_type = DropRefType(std::get<1>(types)); |
| |
|
| | |
| | auto lhs_tt = lhs_type.dyn_cast<TensorType>(); |
| | auto rhs_tt = rhs_type.dyn_cast<TensorType>(); |
| | if (!lhs_tt || !rhs_tt) |
| | { |
| | if (lhs_type != rhs_type) return false; |
| | continue; |
| | } |
| |
|
| | |
| | |
| | |
| | auto lhs_et = lhs_tt.getElementType(); |
| | auto rhs_et = rhs_tt.getElementType(); |
| | if (lhs_et != rhs_et) |
| | { |
| | |
| | auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>(); |
| | auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>(); |
| | if (!lhs_wst || !rhs_wst) return false; |
| |
|
| | |
| | auto lhs_wst_st = lhs_wst.GetSubtypes(); |
| | auto rhs_wst_st = rhs_wst.GetSubtypes(); |
| | if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) |
| | { |
| | for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) |
| | { |
| | if (!BroadcastCompatible(std::get<0>(subtypes), |
| | std::get<1>(subtypes))) |
| | return false; |
| | } |
| | } |
| | } |
| |
|
| | auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>(); |
| | auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>(); |
| | if (!lhs_rt || !rhs_rt) return true; |
| | SmallVector<int64_t, 4> shape; |
| | return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(), |
| | rhs_rt.getShape(), shape); |
| | } |
| | return true; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
// Returns the most refined type to which both `a` and `b` can be cast, or
// null if they are not cast-compatible. With `may_ignore_ref_type_a`,
// a ref wrapper on `a` may be stripped to match `b`. Recurses through
// tensor element types and resource/variant subtypes.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a)
{
  // Identical types are trivially compatible.
  if (a == b) return b;

  auto a_tt = a.dyn_cast<mlir::TensorType>();
  auto b_tt = b.dyn_cast<mlir::TensorType>();

  // A tensor type is never compatible with a non-tensor type.
  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

  // Case 1: both are non-tensor (scalar/element) types.
  if (!a_tt && !b_tt)
  {
    // Optionally strip the ref wrapper from `a` before comparing.
    if (may_ignore_ref_type_a)
    {
      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>())
      {
        a = ref_type.RemoveRef();
        if (a == b) return a;
      }
    }
    // Different concrete kinds can only be reconciled if both carry
    // subtypes; same TypeID with different contents is checked below.
    if (a.getTypeID() != b.getTypeID()) return nullptr;

    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    if (!a_wst || !b_wst) return nullptr;

    // Variants with differing subtypes are accepted as-is (`a` wins);
    // only resources get their subtypes refined below.
    if (a.isa<mlir::TF::VariantType>()) return a;

    // An empty subtype list means "unknown subtypes": keep `a` unchanged.
    // Otherwise the lists must have equal length and refine pairwise.
    auto a_wst_st = a_wst.GetSubtypes();
    auto b_wst_st = b_wst.GetSubtypes();
    if (a_wst_st.empty() || b_wst_st.empty()) return a;
    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st))
    {
      mlir::Type refined_st = GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                        false);
      if (!refined_st) return nullptr;
      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
    }

    // At this point `a` is a resource type (variants returned earlier).
    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
  }

  // Case 2: both are tensors. Element types must be compatible first.
  mlir::Type refined_element_ty = GetCastCompatibleType(
      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
  if (!refined_element_ty) return nullptr;

  // The refined shape takes rank from whichever side has one.
  if (!a_tt.hasRank() && !b_tt.hasRank())
  {
    return mlir::UnrankedTensorType::get(refined_element_ty);
  }
  if (!a_tt.hasRank())
  {
    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
  }
  if (!b_tt.hasRank())
  {
    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
  }

  // Both ranked: merge shapes dimension-by-dimension.
  llvm::SmallVector<int64_t, 8> refined_shape;
  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
    return nullptr;

  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}
| |
|
| | bool HasCompatibleElementTypes(Type lhs, Type rhs, |
| | bool may_ignore_ref_type_lhs) |
| | { |
| | return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr; |
| | } |
| |
|
| | bool AreCastCompatible(TypeRange types) |
| | { |
| | Type common = types.front(); |
| | for (auto type : types.drop_front()) |
| | { |
| | Type refined_type = GetCastCompatibleType(common, type, false); |
| | if (!refined_type) return false; |
| | common = refined_type; |
| | } |
| | return true; |
| | } |
| |
|
| | bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) |
| | { |
| | if (lhs.size() != rhs.size()) return false; |
| | for (auto pair : llvm::zip(lhs, rhs)) |
| | { |
| | auto lhs_i = std::get<0>(pair); |
| | auto rhs_i = std::get<1>(pair); |
| | if (!AreCastCompatible({lhs_i, rhs_i})) return false; |
| | } |
| | return true; |
| | } |
| |
|
| | |
| | |
| | template<typename ComposedType> |
| | Type DropTypeHelper(Type ty) |
| | { |
| | Type element_ty = getElementTypeOrSelf(ty); |
| | auto composed_type = element_ty.dyn_cast<ComposedType>(); |
| | if (!composed_type) return ty; |
| |
|
| | Type default_ty = GetDefaultTypeOf(composed_type); |
| | if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) |
| | { |
| | return RankedTensorType::get(ranked_ty.getShape(), default_ty); |
| | } |
| | else if (ty.dyn_cast<UnrankedTensorType>()) |
| | { |
| | return UnrankedTensorType::get(default_ty); |
| | } |
| | else |
| | { |
| | return default_ty; |
| | } |
| | } |
| |
|
| | Type DropSubTypes(Type ty) |
| | { |
| | return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty); |
| | } |
| |
|
| | Type DropRefType(Type ty) |
| | { |
| | return DropTypeHelper<TF::TensorFlowRefType>(ty); |
| | } |
| |
|
| | Type DropRefAndSubTypes(Type ty) |
| | { |
| | return DropRefType(DropSubTypes(ty)); |
| | } |
| |
|
| | } |
| | } |
| |
|