| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "tf_dialect.h" |
| |
|
| | #include <mlir/Dialect/Traits.h> |
| | #include <mlir/IR/Attributes.h> |
| | #include <mlir/IR/Builders.h> |
| | #include <mlir/IR/Dialect.h> |
| | #include <mlir/IR/DialectImplementation.h> |
| | #include <mlir/IR/Location.h> |
| | #include <mlir/IR/Matchers.h> |
| | #include <mlir/IR/MLIRContext.h> |
| | #include <mlir/IR/OpDefinition.h> |
| | #include <mlir/IR/OpImplementation.h> |
| | #include <mlir/IR/Operation.h> |
| | #include <mlir/IR/OperationSupport.h> |
| | #include <mlir/IR/PatternMatch.h> |
| | #include <mlir/IR/TypeUtilities.h> |
| | #include <mlir/IR/Types.h> |
| | #include <mlir/IR/Value.h> |
| | #include <mlir/IR/Verifier.h> |
| | #include <mlir/Interfaces/CallInterfaces.h> |
| | #include <mlir/Interfaces/DerivedAttributeOpInterface.h> |
| | #include <mlir/Interfaces/InferTypeOpInterface.h> |
| | #include <mlir/Interfaces/LoopLikeInterface.h> |
| | #include <mlir/Interfaces/SideEffectInterfaces.h> |
| | #include <mlir/Parser.h> |
| | #include <mlir/Support/LogicalResult.h> |
| | #include <mlir/Transforms/InliningUtils.h> |
| |
|
| | #include "tf_attributes.h" |
| | #include "tf_side_effects.h" |
| | #include "tf_traits.h" |
| |
|
| | namespace mlir { |
| |
|
| | static LogicalResult Verify(...) |
| | { |
| | return success(); |
| | } |
| | static LogicalResult VerifyPartitionedCall(...) |
| | { |
| | return success(); |
| | } |
| | static LogicalResult VerifyStridedSliceBase(...) |
| | { |
| | return success(); |
| | } |
| | static LogicalResult VerifyUnsortedSegmentReduction(...) |
| | { |
| | return success(); |
| | } |
| |
|
| | namespace TF { |
| |
|
| | TensorFlowDialect::TensorFlowDialect(MLIRContext* context) |
| | : Dialect("tf", context, TypeID::get<TensorFlowDialect>()) |
| | { |
| | addOperations< |
| | #define GET_OP_LIST |
| | #include "tf_all_ops.cc.inc" |
| | >(); |
| | addTypes< |
| | #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, |
| | #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type |
| | #include "tf_types.def" |
| | >(); |
| | |
| | |
| | addAttributes<ShapeAttr, FuncAttr>(); |
| |
|
| | |
| | |
| | allowUnknownOperations(); |
| |
|
| | |
| | |
| | |
| | } |
| |
|
| | namespace { |
| |
|
| | ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc) |
| | { |
| | auto emit_error = [&, spec]() { |
| | emitError(loc, "invalid TensorFlow shape attribute: ") << spec; |
| | return nullptr; |
| | }; |
| |
|
| | if (!spec.consume_front("shape<")) return emit_error(); |
| |
|
| | if (spec.consume_front("*>")) |
| | return mlir::TF::ShapeAttr::get(context, llvm::None); |
| |
|
| | SmallVector<int64_t, 4> shape; |
| | while (!spec.consume_front(">")) |
| | { |
| | int64_t dim; |
| |
|
| | if (spec.consume_front("?")) |
| | dim = -1; |
| | else if (spec.consumeInteger(10, dim) || dim < 0) |
| | return emit_error(); |
| |
|
| | spec.consume_front("x"); |
| |
|
| | shape.push_back(dim); |
| | } |
| |
|
| | return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape)); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc) |
| | { |
| | auto emit_error = [&, spec]() { |
| | emitError(loc, "invalid TensorFlow func attribute: ") << spec; |
| | return nullptr; |
| | }; |
| |
|
| | if (!spec.consume_front("func<")) return emit_error(); |
| |
|
| | size_t func_name_num_read = 0; |
| | Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read); |
| | if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>()) |
| | return emit_error(); |
| | spec = spec.drop_front(func_name_num_read); |
| |
|
| | if (!spec.consume_front(", ")) return emit_error(); |
| |
|
| | size_t func_attrs_num_read = 0; |
| | Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read); |
| | if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>()) |
| | return emit_error(); |
| | spec = spec.drop_front(func_attrs_num_read); |
| |
|
| | if (!spec.consume_front(">")) return emit_error(); |
| |
|
| | return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(), |
| | func_attrs_attr.cast<DictionaryAttr>()); |
| | } |
| |
|
| | } |
| |
|
| | Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser, |
| | Type type) const |
| | { |
| | auto spec = parser.getFullSymbolSpec(); |
| | Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); |
| |
|
| | if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc); |
| |
|
| | if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc); |
| |
|
| | return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr); |
| | } |
| |
|
| | |
| | Type TensorFlowDialect::parseType(DialectAsmParser& parser) const |
| | { |
| | StringRef data; |
| | if (parser.parseKeyword(&data)) return Type(); |
| |
|
| | #define HANDLE_TF_TYPE(tftype, enumerant, name) \ |
| | if (data == name) return tftype##Type::get(getContext()); |
| | |
| | |
| | #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) |
| | |
| | #include "tf_types.def" |
| |
|
| | llvm::SMLoc loc = parser.getNameLoc(); |
| | if (data.startswith("resource")) |
| | { |
| | Type ret = ParseResourceType(parser); |
| | if (!ret) parser.emitError(loc, "invalid resource type"); |
| | return ret; |
| | } |
| | if (data.startswith("variant")) |
| | { |
| | Type ret = ParseVariantType(parser); |
| | if (!ret) parser.emitError(loc, "invalid variant type"); |
| | return ret; |
| | } |
| | return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr); |
| | } |
| |
|
| | namespace { |
| | template<typename TypeWithSubtype> |
| | Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser) |
| | { |
| | |
| | if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context); |
| |
|
| | |
| | SmallVector<TensorType, 1> subtypes; |
| | do |
| | { |
| | TensorType tensor_ty; |
| | if (parser.parseType(tensor_ty)) return Type(); |
| |
|
| | |
| | |
| | if (!IsValidTFTensorType(tensor_ty)) |
| | { |
| | parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty; |
| | return Type(); |
| | } |
| | subtypes.push_back(tensor_ty); |
| | } while (succeeded(parser.parseOptionalComma())); |
| |
|
| | if (parser.parseGreater()) return Type(); |
| |
|
| | return TypeWithSubtype::get(subtypes, context); |
| | } |
| | } |
| |
|
| | Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const |
| | { |
| | return ParseTypeWithSubtype<ResourceType>(getContext(), parser); |
| | } |
| |
|
| | Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const |
| | { |
| | return ParseTypeWithSubtype<VariantType>(getContext(), parser); |
| | } |
| |
|
| | Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder, |
| | Attribute value, Type type, |
| | Location loc) |
| | { |
| | return builder.create<ConstOp>(loc, type, value); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | void ConstOp::build(OpBuilder& builder, OperationState& result, |
| | Attribute value) |
| | { |
| | ShapedType type; |
| | if (auto elem_attr = value.dyn_cast<ElementsAttr>()) |
| | { |
| | return ConstOp::build(builder, result, elem_attr); |
| | } |
| | else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) |
| | { |
| | |
| | |
| | |
| | |
| | type = RankedTensorType::get( {}, value.getType()); |
| | return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); |
| | } |
| | |
| | llvm_unreachable("unsupported attribute type for building tf.Const"); |
| | } |
| |
|
| | void ConstOp::build(OpBuilder& builder, OperationState& result, Type type, |
| | Attribute value) |
| | { |
| | |
| | if (type.isa<TensorType>() && value.isa<ElementsAttr>()) |
| | { |
| | result.addTypes(type); |
| | result.addAttribute("value", value); |
| | return; |
| | } |
| |
|
| | |
| | ConstOp::build(builder, result, value); |
| | assert(type == result.types[0] && "type mismatch in construction"); |
| | } |
| |
|
| | Region& WhileRegionOp::getLoopBody() |
| | { |
| | return body(); |
| | } |
| |
|
| | bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) |
| | { |
| | |
| | |
| | |
| | |
| | Operation* def_op = value.getDefiningOp(); |
| | return def_op && !getOperation()->isAncestor(def_op); |
| | } |
| |
|
| | LogicalResult WhileRegionOp::moveOutOfLoop( |
| | llvm::ArrayRef<mlir::Operation*> ops) |
| | { |
| | |
| | Operation* while_op = this->getOperation(); |
| | for (auto op : ops) op->moveBefore(while_op); |
| | return success(); |
| | } |
| |
|
| | } |
| |
|
| | } |
| |
|
| | #define GET_OP_CLASSES |
| | #include "tf_all_ops.cc.inc" |
| |
|