// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "tf_dialect.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Parser.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/InliningUtils.h>
#include "tf_attributes.h"
#include "tf_side_effects.h"
#include "tf_traits.h"
namespace mlir {
// Catch-all op verifier stub: accepts any operation unconditionally.
// NOTE(review): the real per-op checks from upstream TensorFlow/MLIR are
// presumably not ported in this tool — verification is intentionally a no-op.
static LogicalResult Verify(...)
{
    return success();
}
// Verifier stub for tf.PartitionedCall-style ops: accepts unconditionally
// (callee symbol/signature checks from upstream are not ported here).
static LogicalResult VerifyPartitionedCall(...)
{
    return success();
}
// Verifier stub for strided-slice ops: accepts unconditionally
// (mask/stride validation from upstream is not ported here).
static LogicalResult VerifyStridedSliceBase(...)
{
    return success();
}
// Verifier stub for unsorted-segment reduction ops: accepts unconditionally
// (segment-id shape checks from upstream are not ported here).
static LogicalResult VerifyUnsortedSegmentReduction(...)
{
    return success();
}
namespace TF {
// Registers the "tf" dialect: all generated ops, all TF types, and the two
// custom attributes (#tf.shape / #tf.func). Unknown ops are allowed because
// this tool does not register the full TensorFlow op set.
TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
{
    // Ops come from the tablegen-generated list.
    addOperations<
#define GET_OP_LIST
#include "tf_all_ops.cc.inc"
        >();
    // Types come from the X-macro list in tf_types.def; the last entry uses a
    // separate macro so the expansion has no trailing comma.
    addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tf_types.def"
        >();
    //  addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
    //                TFConstantFoldInterface>();
    addAttributes<ShapeAttr, FuncAttr>();
    // Support unknown operations because not all TensorFlow operations are
    // registered.
    allowUnknownOperations();
    //  for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
    //    hook(*this);
    //  }
}
namespace {
// Parses a #tf.shape attribute of the form "shape<*>" (unranked) or
// "shape<d0xd1x...>" where each dimension is a non-negative integer or '?'
// for a dynamic dimension. Emits a diagnostic at `loc` and returns a null
// attribute on malformed input.
ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto report_failure = [&, spec]() {
        emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
        return nullptr;
    };
    if (!spec.consume_front("shape<")) return report_failure();
    // "shape<*>" denotes an unranked shape.
    if (spec.consume_front("*>"))
        return mlir::TF::ShapeAttr::get(context, llvm::None);
    SmallVector<int64_t, 4> dims;
    while (!spec.consume_front(">"))
    {
        int64_t dim_size;
        if (spec.consume_front("?"))
        {
            // '?' marks a dynamic dimension, encoded as -1.
            dim_size = -1;
        }
        else if (spec.consumeInteger(10, dim_size) || dim_size < 0)
        {
            return report_failure();
        }
        // Dimensions are 'x'-separated; the last one has no separator, so the
        // consume may fail harmlessly.
        spec.consume_front("x");
        dims.push_back(dim_size);
    }
    return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
}
// Parses a #tf.func attribute of the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr and the second element is a
// DictionaryAttr. Emits a diagnostic at `loc` and returns a null attribute
// on malformed input.
FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto report_failure = [&, spec]() {
        emitError(loc, "invalid TensorFlow func attribute: ") << spec;
        return nullptr;
    };
    if (!spec.consume_front("func<")) return report_failure();
    // First component: the referenced function symbol.
    size_t symbol_bytes_read = 0;
    Attribute symbol = mlir::parseAttribute(spec, context, symbol_bytes_read);
    if (!symbol || !symbol.isa<SymbolRefAttr>())
        return report_failure();
    spec = spec.drop_front(symbol_bytes_read);
    if (!spec.consume_front(", ")) return report_failure();
    // Second component: the attribute dictionary.
    size_t dict_bytes_read = 0;
    Attribute dict = mlir::parseAttribute(spec, context, dict_bytes_read);
    if (!dict || !dict.isa<DictionaryAttr>())
        return report_failure();
    spec = spec.drop_front(dict_bytes_read);
    if (!spec.consume_front(">")) return report_failure();
    return mlir::TF::FuncAttr::get(context, symbol.cast<SymbolRefAttr>(),
                                   dict.cast<DictionaryAttr>());
}
} // namespace
// Dispatches parsing of dialect attributes by spelling prefix: "shape..." to
// ParseShapeAttr and "func..." to ParseFuncAttr. Anything else is an error.
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
                                            Type type) const
{
    StringRef attr_spec = parser.getFullSymbolSpec();
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
    if (attr_spec.startswith("shape"))
        return ParseShapeAttr(getContext(), attr_spec, loc);
    if (attr_spec.startswith("func"))
        return ParseFuncAttr(getContext(), attr_spec, loc);
    emitError(loc, "unknown TensorFlow attribute: " + attr_spec);
    return nullptr;
}
// Parses a type registered to this dialect. Exact keyword matches are handled
// by the X-macro expansion of tf_types.def; "resource..." and "variant..."
// are matched by prefix because they may carry a subtype list.
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
{
    StringRef data;
    if (parser.parseKeyword(&data)) return Type();
// Each simple TF type gets an exact-name comparison generated here.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    llvm::SMLoc loc = parser.getNameLoc();
    if (data.startswith("resource"))
    {
        Type ret = ParseResourceType(parser);
        if (!ret) parser.emitError(loc, "invalid resource type");
        return ret;
    }
    if (data.startswith("variant"))
    {
        Type ret = ParseVariantType(parser);
        if (!ret) parser.emitError(loc, "invalid variant type");
        return ret;
    }
    return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}
namespace {
// Parses a TF type that may carry a list of inferred tensor subtypes, e.g.
//   !tf.resource                      (no subtype list)
//   !tf.resource<tensor<32xf32>>      (one or more comma-separated subtypes)
// Returns a null Type on parse failure.
template<typename TypeWithSubtype>
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
{
    // No '<' means the default type without inferred subtypes.
    if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
    // Most types with subtypes have only one subtype.
    SmallVector<TensorType, 1> subtypes;
    while (true)
    {
        TensorType subtype;
        if (parser.parseType(subtype)) return Type();
        // Each of the subtypes should be a valid TensorFlow type.
        // TODO(jpienaar): Remove duplication.
        if (!IsValidTFTensorType(subtype))
        {
            parser.emitError(parser.getNameLoc()) << "invalid subtype: " << subtype;
            return Type();
        }
        subtypes.push_back(subtype);
        if (failed(parser.parseOptionalComma())) break;
    }
    if (parser.parseGreater()) return Type();
    return TypeWithSubtype::get(subtypes, context);
}
} // anonymous namespace
// Parses "!tf.resource" with an optional subtype list; see ParseTypeWithSubtype.
Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
}
// Parses "!tf.variant" with an optional subtype list; see ParseTypeWithSubtype.
Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<VariantType>(getContext(), parser);
}
// Materializes a constant for folding: every constant in the tf dialect is a
// tf.Const op built from `value` with result type `type`.
Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
                                                  Attribute value, Type type,
                                                  Location loc)
{
    return builder.create<ConstOp>(loc, type, value);
}
// Builds a constant op with the specified attribute `value`. The result
// op's type is deduced from `value`; if `value` is of scalar type,
// wraps it up with a tensor type of empty shape.
// TODO(jpienaar): This one differs from the autogenerated one as it takes an
// attribute but always creates an ElementsAttr internally.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value)
{
    if (auto elements = value.dyn_cast<ElementsAttr>())
    {
        // Already an ElementsAttr: delegate directly.
        ConstOp::build(builder, result, elements);
        return;
    }
    if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
    {
        // All TensorFlow types must be tensor types. In the build() method,
        // we want to provide more flexibility by allowing attributes of scalar
        // types. But we need to wrap it up with ElementsAttr to construct
        // valid TensorFlow constants.
        ShapedType scalar_ty = RankedTensorType::get(/*shape=*/ {}, value.getType());
        ConstOp::build(builder, result, DenseElementsAttr::get(scalar_ty, value));
        return;
    }
    // TODO(jpienaar): support other TensorFlow specific types.
    llvm_unreachable("unsupported attribute type for building tf.Const");
}
// Builds a constant op from an explicit result `type` and attribute `value`.
// When both are already in tensor form they are recorded directly; otherwise
// the attribute-only builder deduces the type, which must match `type`.
void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
                    Attribute value)
{
    // Fast path: type and value are already tensors, record them verbatim.
    if (value.isa<ElementsAttr>() && type.isa<TensorType>())
    {
        result.addTypes(type);
        result.addAttribute("value", value);
        return;
    }
    // Slow path: delegate to the attribute builder and check the deduced type.
    ConstOp::build(builder, result, value);
    assert(type == result.types[0] && "type mismatch in construction");
}
// LoopLikeOpInterface hook: the loop body is the op's `body` region.
Region& WhileRegionOp::getLoopBody()
{
    return body();
}
// A value is defined outside the loop when it has a defining op (i.e. it is
// not a block argument) and this WhileRegionOp is not an ancestor of that op
// in the parent chain.
bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
{
    Operation* producer = value.getDefiningOp();
    // Block arguments have no defining op and are not considered outside.
    if (!producer) return false;
    return !getOperation()->isAncestor(producer);
}
// LoopLikeOpInterface hook: hoists `ops` out of the loop by moving each one
// immediately before this while op, preserving their relative order.
LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation*> ops)
{
    Operation* insertion_anchor = this->getOperation();
    for (Operation* hoisted : ops)
    {
        hoisted->moveBefore(insertion_anchor);
    }
    return success();
}
} // namespace TF
} // namespace mlir
#define GET_OP_CLASSES
#include "tf_all_ops.cc.inc"