| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include "tf_attributes.h" |
|
|
| namespace mlir { |
| namespace TF { |
|
|
| namespace detail { |
|
|
| |
| struct ShapeAttrStorage : public AttributeStorage |
| { |
| using KeyTy = std::pair<ArrayRef<int64_t>, bool>; |
|
|
| explicit ShapeAttrStorage(ArrayRef<int64_t> shape, bool unranked = false) |
| : shape(shape), unranked(unranked) |
| { |
| } |
|
|
| bool operator==(const KeyTy& key) const |
| { |
| return key == KeyTy(shape, unranked); |
| } |
| static unsigned hashKey(const KeyTy& key) |
| { |
| return llvm::hash_combine(key.first, static_cast<char>(key.second)); |
| } |
|
|
| |
| static ShapeAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, |
| const KeyTy& key) |
| { |
| return new (allocator.allocate<ShapeAttrStorage>()) |
| ShapeAttrStorage(allocator.copyInto(key.first), key.second); |
| } |
|
|
| ArrayRef<int64_t> shape; |
| bool unranked = false; |
| }; |
|
|
| |
| struct FuncAttrStorage : public AttributeStorage |
| { |
| using KeyTy = std::pair<Attribute, Attribute>; |
|
|
| explicit FuncAttrStorage(Attribute name, Attribute attrs) |
| : name(name), attrs(attrs) |
| { |
| } |
|
|
| bool operator==(const KeyTy& key) const |
| { |
| return key == KeyTy(name, attrs); |
| } |
| static unsigned hashKey(const KeyTy& key) |
| { |
| return llvm::hash_combine(key.first, key.second); |
| } |
|
|
| static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, |
| const KeyTy& key) |
| { |
| return new (allocator.allocate<FuncAttrStorage>()) |
| FuncAttrStorage(key.first, key.second); |
| } |
|
|
| Attribute name; |
| Attribute attrs; |
| }; |
|
|
| } |
|
|
| |
| ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, |
| llvm::Optional<ArrayRef<int64_t> > shape) |
| { |
| if (shape) return Base::get(context, *shape, false); |
|
|
| return Base::get(context, ArrayRef<int64_t>(), true); |
| } |
|
|
| |
| ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type) |
| { |
| if (shaped_type.hasRank()) |
| return Base::get(context, shaped_type.getShape(), false); |
|
|
| return Base::get(context, ArrayRef<int64_t>(), true); |
| } |
|
|
| llvm::Optional<ArrayRef<int64_t> > ShapeAttr::getValue() const |
| { |
| if (hasRank()) return getShape(); |
| return llvm::None; |
| } |
|
|
| bool ShapeAttr::hasRank() const |
| { |
| return !getImpl()->unranked; |
| } |
|
|
| int64_t ShapeAttr::getRank() const |
| { |
| assert(hasRank()); |
| return getImpl()->shape.size(); |
| } |
|
|
| ArrayRef<int64_t> ShapeAttr::getShape() const |
| { |
| assert(hasRank()); |
| return getImpl()->shape; |
| } |
|
|
| bool ShapeAttr::hasStaticShape() const |
| { |
| if (!hasRank()) return false; |
|
|
| for (auto dim : getShape()) |
| { |
| if (dim < 0) return false; |
| } |
|
|
| return true; |
| } |
|
|
| FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, |
| DictionaryAttr attr) |
| { |
| auto symbol = SymbolRefAttr::get(context, name); |
| return Base::get(context, symbol, attr); |
| } |
|
|
| FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, |
| DictionaryAttr attr) |
| { |
| return Base::get(context, symbol, attr); |
| } |
|
|
| SymbolRefAttr FuncAttr::GetName() const |
| { |
| return getImpl()->name.cast<SymbolRefAttr>(); |
| } |
|
|
| DictionaryAttr FuncAttr::GetAttrs() const |
| { |
| return getImpl()->attrs.cast<DictionaryAttr>(); |
| } |
|
|
| } |
| } |
|
|