Spaces:
Build error
Build error
| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| ==============================================================================*/ | |
| namespace tensorflow { | |
| Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, | |
| int64 dilation_rate, int64 stride, | |
| Padding padding_type, int64* output_size, | |
| int64* padding_before, | |
| int64* padding_after) { | |
| if (stride <= 0) { | |
| return errors::InvalidArgument("Stride must be > 0, but got ", stride); | |
| } | |
| if (dilation_rate < 1) { | |
| return errors::InvalidArgument("Dilation rate must be >= 1, but got ", | |
| dilation_rate); | |
| } | |
| // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. | |
| int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; | |
| switch (padding_type) { | |
| case Padding::VALID: | |
| *output_size = (input_size - effective_filter_size + stride) / stride; | |
| *padding_before = *padding_after = 0; | |
| break; | |
| case Padding::SAME: | |
| *output_size = (input_size + stride - 1) / stride; | |
| const int64 padding_needed = | |
| std::max(0LL, (*output_size - 1) * stride + effective_filter_size - | |
| input_size); | |
| // For odd values of total padding, add more padding at the 'right' | |
| // side of the given dimension. | |
| *padding_before = padding_needed / 2; | |
| *padding_after = padding_needed - *padding_before; | |
| break; | |
| } | |
| if (*output_size < 0) { | |
| return errors::InvalidArgument("computed output size would be negative"); | |
| } | |
| return Status::OK(); | |
| } | |
| Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size, | |
| int64 stride, Padding padding_type, | |
| int64* output_size, int64* padding_before, | |
| int64* padding_after) { | |
| return GetWindowedOutputSizeVerboseV2(input_size, filter_size, | |
| /*dilation_rate=*/1, stride, | |
| padding_type, output_size, | |
| padding_before, padding_after); | |
| } | |
| Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride, | |
| Padding padding_type, int64* output_size, | |
| int64* padding_size) { | |
| int64 padding_after_unused; | |
| return GetWindowedOutputSizeVerbose(input_size, filter_size, stride, | |
| padding_type, output_size, padding_size, | |
| &padding_after_unused); | |
| } | |
| Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size, | |
| int64 dilation_rate, int64 stride, | |
| Padding padding_type, int64* output_size, | |
| int64* padding_size) { | |
| int64 padding_after_unused; | |
| return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, | |
| stride, padding_type, output_size, | |
| padding_size, &padding_after_unused); | |
| } | |
| Status Get3dOutputSize(const std::array<int64, 3>& input, | |
| const std::array<int64, 3>& window, | |
| const std::array<int64, 3>& strides, | |
| Padding padding_type, std::array<int64, 3>* output_ptr, | |
| std::array<int64, 3>* padding_ptr) { | |
| for (size_t i = 0; i < input.size(); ++i) { | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i], | |
| padding_type, &(*output_ptr)[i], | |
| &(*padding_ptr)[i])); | |
| } | |
| return Status::OK(); | |
| } | |
| Status Get3dOutputSizeV2(const std::array<int64, 3>& input, | |
| const std::array<int64, 3>& window, | |
| const std::array<int64, 3>& dilations, | |
| const std::array<int64, 3>& strides, | |
| Padding padding_type, std::array<int64, 3>* output_ptr, | |
| std::array<int64, 3>* padding_ptr) { | |
| for (size_t i = 0; i < input.size(); ++i) { | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( | |
| input[i], window[i], dilations[i], strides[i], padding_type, | |
| &(*output_ptr)[i], &(*padding_ptr)[i])); | |
| } | |
| return Status::OK(); | |
| } | |
| namespace shape_inference { | |
| // The V2 version computes windowed output size with arbitrary dilation_rate, | |
| // while the original version only handles the cases where dilation_rates equal | |
| // to 1. | |
| Status GetWindowedOutputSizeFromDimsV2( | |
| shape_inference::InferenceContext* c, | |
| shape_inference::DimensionHandle input_size, | |
| shape_inference::DimensionOrConstant filter_size, int64 dilation_rate, | |
| int64 stride, Padding padding_type, | |
| shape_inference::DimensionHandle* output_size) { | |
| if (stride <= 0) { | |
| return errors::InvalidArgument("Stride must be > 0, but got ", stride); | |
| } | |
| if (dilation_rate < 1) { | |
| return errors::InvalidArgument("Dilation rate must be >= 1, but got ", | |
| dilation_rate); | |
| } | |
| // See also the parallel implementation in GetWindowedOutputSizeVerbose. | |
| switch (padding_type) { | |
| case Padding::VALID: | |
| if (dilation_rate > 1) { | |
| DimensionHandle window_size; | |
| TF_RETURN_IF_ERROR( | |
| c->Subtract(c->MakeDim(filter_size), 1, &window_size)); | |
| TF_RETURN_IF_ERROR( | |
| c->Multiply(window_size, dilation_rate, &window_size)); | |
| TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size)); | |
| TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size)); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size)); | |
| } | |
| TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size)); | |
| TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, | |
| /*evenly_divisible=*/false, output_size)); | |
| break; | |
| case Padding::SAME: | |
| TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size)); | |
| TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, | |
| /*evenly_divisible=*/false, output_size)); | |
| break; | |
| } | |
| return Status::OK(); | |
| } | |
| Status GetWindowedOutputSizeFromDims( | |
| shape_inference::InferenceContext* c, | |
| shape_inference::DimensionHandle input_size, | |
| shape_inference::DimensionOrConstant filter_size, int64 stride, | |
| Padding padding_type, shape_inference::DimensionHandle* output_size) { | |
| return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size, | |
| /*dilation_rate=*/1, stride, | |
| padding_type, output_size); | |
| } | |
| Status UnchangedShape(shape_inference::InferenceContext* c) { | |
| c->set_output(0, c->input(0)); | |
| return Status::OK(); | |
| } | |
| Status MatMulShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle a; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); | |
| ShapeHandle b; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b)); | |
| bool transpose_a, transpose_b; | |
| TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a)); | |
| TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b)); | |
| DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0); | |
| DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1); | |
| // Validate that the inner shapes are compatible. | |
| DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1); | |
| DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0); | |
| DimensionHandle merged; | |
| TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); | |
| c->set_output(0, c->Matrix(output_rows, output_cols)); | |
| return Status::OK(); | |
| } | |
| Status BiasAddShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle input_shape; | |
| // Fetch the data_format attribute, which may not exist. | |
| string data_format; | |
| Status s = c->GetAttr("data_format", &data_format); | |
| if (s.ok() && data_format == "NCHW") { | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); | |
| } | |
| ShapeHandle bias_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape)); | |
| DimensionHandle bias_dim = c->Dim(bias_shape, 0); | |
| // If rank unknown, return unknown shape. | |
| if (!c->RankKnown(input_shape)) { | |
| c->set_output(0, c->UnknownShape()); | |
| return Status::OK(); | |
| } | |
| // Output has the same shape as the input, and matches the length of | |
| // the bias in its bias dimension. | |
| ShapeHandle output_shape; | |
| if (s.ok() && data_format == "NCHW") { | |
| // Merge the length of bias_shape into the third to last dimension | |
| ShapeHandle first; | |
| TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first)); | |
| ShapeHandle last; | |
| TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last)); | |
| DimensionHandle input_bias_dim = c->Dim(input_shape, -3); | |
| DimensionHandle merged_bias_dim; | |
| TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); | |
| ShapeHandle merged_bias = c->Vector(merged_bias_dim); | |
| ShapeHandle temp; | |
| TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp)); | |
| TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape)); | |
| } else { | |
| ShapeHandle all_but_bias; | |
| TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias)); | |
| DimensionHandle input_bias_dim = c->Dim(input_shape, -1); | |
| DimensionHandle merged_bias_dim; | |
| TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); | |
| ShapeHandle merged_bias = c->Vector(merged_bias_dim); | |
| TF_RETURN_IF_ERROR( | |
| c->Concatenate(all_but_bias, merged_bias, &output_shape)); | |
| } | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status BiasAddGradShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle input_shape; | |
| // Fetch the data_format attribute, which may not exist. | |
| string data_format; | |
| Status s = c->GetAttr("data_format", &data_format); | |
| if (s.ok() && data_format == "NCHW") { | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); | |
| c->set_output(0, c->Vector(c->Dim(input_shape, -3))); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); | |
| c->set_output(0, c->Vector(c->Dim(input_shape, -1))); | |
| } | |
| return Status::OK(); | |
| } | |
| Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, | |
| const ShapeHandle shape_handle, | |
| const string& tensor_name, | |
| shape_inference::InferenceContext* c) { | |
| if (tensor_format == FORMAT_NCHW_VECT_C) { | |
| // Check that the vect dim has size 4. | |
| const int num_dims = c->Rank(shape_handle); | |
| DimensionHandle vect_dim = c->Dim( | |
| shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); | |
| DimensionHandle unused_vect_dim; | |
| TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim)); | |
| } | |
| return Status::OK(); | |
| } | |
| Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, | |
| const std::vector<DimensionOrConstant>& spatial, | |
| DimensionOrConstant C, ShapeHandle* out, | |
| shape_inference::InferenceContext* context) { | |
| const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); | |
| std::vector<DimensionHandle> dims_actual(num_dims); | |
| dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); | |
| int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); | |
| dims_actual[outer_c_index] = context->MakeDim(C); | |
| if (format == FORMAT_NCHW_VECT_C) { | |
| dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = | |
| context->MakeDim(4); | |
| } | |
| for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) { | |
| dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = | |
| context->MakeDim(spatial[spatial_dim]); | |
| } | |
| *out = context->MakeShape(dims_actual); | |
| return Status::OK(); | |
| } | |
| Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, | |
| DimensionHandle* batch_dim, | |
| gtl::MutableArraySlice<DimensionHandle> spatial_dims, | |
| DimensionHandle* filter_dim, | |
| InferenceContext* context) { | |
| const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); | |
| // Batch. | |
| *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); | |
| // Spatial. | |
| for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); | |
| ++spatial_dim_index) { | |
| spatial_dims[spatial_dim_index] = context->Dim( | |
| shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index)); | |
| } | |
| // Channel. | |
| *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format)); | |
| if (format == FORMAT_NCHW_VECT_C) { | |
| TF_RETURN_IF_ERROR(context->Multiply( | |
| *filter_dim, | |
| context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), | |
| filter_dim)); | |
| } | |
| return Status::OK(); | |
| } | |
| Status ShapeFromDimensions(DimensionHandle batch_dim, | |
| gtl::ArraySlice<DimensionHandle> spatial_dims, | |
| DimensionHandle filter_dim, TensorFormat format, | |
| InferenceContext* context, ShapeHandle* shape) { | |
| const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); | |
| std::vector<DimensionHandle> out_dims(rank); | |
| // Batch. | |
| out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; | |
| // Spatial. | |
| for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); | |
| ++spatial_dim_index) { | |
| out_dims[tensorflow::GetTensorSpatialDimIndex( | |
| rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index]; | |
| } | |
| // Channel. | |
| if (format == tensorflow::FORMAT_NCHW_VECT_C) { | |
| // When format is NCHW_VECT_C, factor the feature map count | |
| // into the outer feature count and the inner feature count (=4). | |
| TF_RETURN_IF_ERROR(context->Divide( | |
| filter_dim, 4, /*evenly_divisible=*/true, | |
| &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)])); | |
| out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4); | |
| } else { | |
| out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim; | |
| } | |
| *shape = context->MakeShape(out_dims); | |
| return tensorflow::Status::OK(); | |
| } | |
| Status Conv2DShape(shape_inference::InferenceContext* c) { | |
| string data_format_str, filter_format_str; | |
| if (!c->GetAttr("data_format", &data_format_str).ok()) { | |
| data_format_str = "NHWC"; | |
| } | |
| if (!c->GetAttr("filter_format", &filter_format_str).ok()) { | |
| filter_format_str = "HWIO"; | |
| } | |
| TensorFormat data_format; | |
| if (!FormatFromString(data_format_str, &data_format)) { | |
| return errors::InvalidArgument("Invalid data format string: ", | |
| data_format_str); | |
| } | |
| FilterTensorFormat filter_format; | |
| if (!FilterFormatFromString(filter_format_str, &filter_format)) { | |
| return errors::InvalidArgument("Invalid filter format string: ", | |
| filter_format_str); | |
| } | |
| constexpr int num_spatial_dims = 2; | |
| const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); | |
| ShapeHandle conv_input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); | |
| TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( | |
| data_format, conv_input_shape, "conv_input", c)); | |
| // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). | |
| ShapeHandle filter_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); | |
| TF_RETURN_IF_ERROR( | |
| CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); | |
| std::vector<int32> dilations; | |
| TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); | |
| if (dilations.size() != 4) { | |
| return errors::InvalidArgument( | |
| "Conv2D requires the dilation attribute to contain 4 values, but got: ", | |
| dilations.size()); | |
| } | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). | |
| if (strides.size() != 4) { | |
| return errors::InvalidArgument("Conv2D on data format ", data_format_str, | |
| " requires the stride attribute to contain" | |
| " 4 values, but got: ", | |
| strides.size()); | |
| } | |
| const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); | |
| const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); | |
| const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H'); | |
| const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W'); | |
| DimensionHandle batch_size_dim; | |
| DimensionHandle input_depth_dim; | |
| gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2); | |
| TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format, | |
| &batch_size_dim, &input_spatial_dims, | |
| &input_depth_dim, c)); | |
| DimensionHandle output_depth_dim = c->Dim( | |
| filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O')); | |
| DimensionHandle filter_rows_dim = c->Dim( | |
| filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H')); | |
| DimensionHandle filter_cols_dim = c->Dim( | |
| filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W')); | |
| DimensionHandle filter_input_depth_dim; | |
| if (filter_format == FORMAT_OIHW_VECT_I) { | |
| TF_RETURN_IF_ERROR(c->Multiply( | |
| c->Dim(filter_shape, | |
| GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')), | |
| c->Dim(filter_shape, | |
| GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), | |
| &filter_input_depth_dim)); | |
| } else { | |
| filter_input_depth_dim = c->Dim( | |
| filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')); | |
| } | |
| // Check that the input tensor and the filter tensor agree on the input | |
| // channel count. | |
| DimensionHandle unused; | |
| TF_RETURN_IF_ERROR( | |
| c->Merge(input_depth_dim, filter_input_depth_dim, &unused)); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| DimensionHandle output_rows, output_cols; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( | |
| c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows, | |
| padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( | |
| c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols, | |
| padding, &output_cols)); | |
| ShapeHandle output_shape; | |
| TF_RETURN_IF_ERROR( | |
| ShapeFromDimensions(batch_size_dim, {output_rows, output_cols}, | |
| output_depth_dim, data_format, c, &output_shape)); | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| // TODO(mjanusz): Unify all conv/pooling shape functions. | |
| Status Conv3DShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); | |
| ShapeHandle filter_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); | |
| string data_format; | |
| Status s = c->GetAttr("data_format", &data_format); | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| if (strides.size() != 5) { | |
| return errors::InvalidArgument( | |
| "Conv3D requires the stride attribute to contain 5 values, but got: ", | |
| strides.size()); | |
| } | |
| int32 stride_planes, stride_rows, stride_cols; | |
| if (s.ok() && data_format == "NCDHW") { | |
| // Convert input_shape to NDHWC. | |
| auto dim = [&](char dimension) { | |
| return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); | |
| }; | |
| input_shape = | |
| c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); | |
| stride_planes = strides[2]; | |
| stride_cols = strides[3]; | |
| stride_rows = strides[4]; | |
| } else { | |
| stride_planes = strides[1]; | |
| stride_rows = strides[2]; | |
| stride_cols = strides[3]; | |
| } | |
| DimensionHandle batch_size_dim = c->Dim(input_shape, 0); | |
| DimensionHandle in_planes_dim = c->Dim(input_shape, 1); | |
| DimensionHandle in_rows_dim = c->Dim(input_shape, 2); | |
| DimensionHandle in_cols_dim = c->Dim(input_shape, 3); | |
| DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); | |
| DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); | |
| DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); | |
| DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); | |
| DimensionHandle unused; | |
| TF_RETURN_IF_ERROR( | |
| c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| DimensionHandle output_planes, output_rows, output_cols; | |
| TF_RETURN_IF_ERROR( | |
| GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim, | |
| stride_planes, padding, &output_planes)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); | |
| ShapeHandle output_shape; | |
| if (data_format == "NCDHW") { | |
| output_shape = c->MakeShape({batch_size_dim, output_depth_dim, | |
| output_planes, output_rows, output_cols}); | |
| } else { | |
| output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, | |
| output_cols, output_depth_dim}); | |
| } | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); | |
| ShapeHandle filter_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| if (strides.size() != 4) { | |
| return errors::InvalidArgument( | |
| "DepthwiseConv2D requires the stride attribute to contain 4 values, " | |
| "but got: ", | |
| strides.size()); | |
| } | |
| string data_format; | |
| Status s = c->GetAttr("data_format", &data_format); | |
| int32 stride_rows; | |
| int32 stride_cols; | |
| if (s.ok() && data_format == "NCHW") { | |
| // Canonicalize input shape to NHWC so the shape inference code below can | |
| // process it. | |
| input_shape = | |
| c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), | |
| c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); | |
| stride_rows = strides[2]; | |
| stride_cols = strides[3]; | |
| } else { | |
| stride_rows = strides[1]; | |
| stride_cols = strides[2]; | |
| } | |
| DimensionHandle batch_size_dim = c->Dim(input_shape, 0); | |
| DimensionHandle in_rows_dim = c->Dim(input_shape, 1); | |
| DimensionHandle in_cols_dim = c->Dim(input_shape, 2); | |
| DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); | |
| DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); | |
| DimensionHandle input_depth = c->Dim(filter_shape, 2); | |
| DimensionHandle depth_multiplier = c->Dim(filter_shape, 3); | |
| // Check that the input depths are compatible. | |
| TF_RETURN_IF_ERROR( | |
| c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); | |
| DimensionHandle output_depth; | |
| TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| // TODO(mrry,shlens): Raise an error if the stride would cause | |
| // information in the input to be ignored. This will require a change | |
| // in the kernel implementation. | |
| DimensionHandle output_rows, output_cols; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); | |
| ShapeHandle output_shape; | |
| if (data_format == "NCHW") { | |
| output_shape = | |
| c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols}); | |
| } else { | |
| output_shape = | |
| c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); | |
| } | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status AvgPoolShape(shape_inference::InferenceContext* c) { | |
| string data_format_str; | |
| TensorFormat data_format; | |
| Status s = c->GetAttr("data_format", &data_format_str); | |
| if (s.ok()) { | |
| FormatFromString(data_format_str, &data_format); | |
| } else { | |
| data_format = FORMAT_NHWC; | |
| } | |
| const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); | |
| TF_RETURN_IF_ERROR( | |
| CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| if (strides.size() != 4) { | |
| return errors::InvalidArgument( | |
| "AvgPool requires the stride attribute to contain 4 values, but got: ", | |
| strides.size()); | |
| } | |
| std::vector<int32> kernel_sizes; | |
| TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); | |
| if (kernel_sizes.size() != 4) { | |
| return errors::InvalidArgument( | |
| "AvgPool requires the ksize attribute to contain 4 values, but got: ", | |
| kernel_sizes.size()); | |
| } | |
| int32 stride_rows = GetTensorDim(strides, data_format, 'H'); | |
| int32 stride_cols = GetTensorDim(strides, data_format, 'W'); | |
| int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); | |
| int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); | |
| constexpr int num_spatial_dims = 2; | |
| DimensionHandle batch_size_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); | |
| DimensionHandle in_rows_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); | |
| DimensionHandle in_cols_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); | |
| DimensionHandle depth_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| // TODO(mrry,shlens): Raise an error if the stride would cause | |
| // information in the input to be ignored. This will require a change | |
| // in the kernel implementation. | |
| DimensionHandle output_rows, output_cols; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); | |
| ShapeHandle output_shape; | |
| TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, | |
| {output_rows, output_cols}, depth_dim, | |
| &output_shape, c)); | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status FusedBatchNormShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle x; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); | |
| bool is_training; | |
| TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); | |
| int number_inputs = (is_training) ? 3 : 5; | |
| string data_format; | |
| TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); | |
| DimensionHandle channel_dim = | |
| (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1); | |
| // covers scale, offset, and if is_training is false, mean, variance | |
| for (int i = 1; i < number_inputs; ++i) { | |
| ShapeHandle vec; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); | |
| TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); | |
| } | |
| ShapeHandle y; | |
| if (data_format == "NHWC") { | |
| TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y)); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y)); | |
| } | |
| c->set_output(0, y); | |
| ShapeHandle vector_shape = c->Vector(channel_dim); | |
| c->set_output(1, vector_shape); | |
| c->set_output(2, vector_shape); | |
| c->set_output(3, vector_shape); | |
| c->set_output(4, vector_shape); | |
| return Status::OK(); | |
| } | |
| Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle y_backprop; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); | |
| ShapeHandle x; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); | |
| bool is_training; | |
| string data_format; | |
| TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); | |
| TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); | |
| DimensionHandle channel_dim = | |
| (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1); | |
| if (data_format == "NHWC") { | |
| TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim)); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim)); | |
| } | |
| // covers scale, mean (reserve_space_1), variance (reserve_space_2) | |
| for (int i = 2; i < 5; ++i) { | |
| ShapeHandle vec; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); | |
| TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); | |
| } | |
| ShapeHandle x_backprop; | |
| if (data_format == "NHWC") { | |
| TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop)); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop)); | |
| } | |
| c->set_output(0, x_backprop); | |
| c->set_output(1, c->Vector(channel_dim)); | |
| c->set_output(2, c->Vector(channel_dim)); | |
| // Set the correct shapes for reserve_spaces | |
| // so that gradients can be performed when | |
| // the op is in a symbolic condition. | |
| if (is_training) { | |
| c->set_output(3, c->Vector(0)); | |
| c->set_output(4, c->Vector(0)); | |
| } else { | |
| c->set_output(3, c->Vector(channel_dim)); | |
| c->set_output(4, c->Vector(channel_dim)); | |
| } | |
| return Status::OK(); | |
| } | |
| Status MaxPoolShape(shape_inference::InferenceContext* c) { | |
| string data_format_str; | |
| TensorFormat data_format; | |
| Status s = c->GetAttr("data_format", &data_format_str); | |
| if (s.ok()) { | |
| FormatFromString(data_format_str, &data_format); | |
| } else { | |
| data_format = FORMAT_NHWC; | |
| } | |
| const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); | |
| TF_RETURN_IF_ERROR( | |
| CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| if (strides.size() != 4) { | |
| return errors::InvalidArgument( | |
| "MaxPool requires the stride attribute to contain 4 values, but got: ", | |
| strides.size()); | |
| } | |
| std::vector<int32> kernel_sizes; | |
| TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); | |
| if (kernel_sizes.size() != 4) { | |
| return errors::InvalidArgument( | |
| "MaxPool requires the ksize attribute to contain 4 values, but got: ", | |
| kernel_sizes.size()); | |
| } | |
| int32 stride_depth = GetTensorDim(strides, data_format, 'C'); | |
| int32 stride_rows = GetTensorDim(strides, data_format, 'H'); | |
| int32 stride_cols = GetTensorDim(strides, data_format, 'W'); | |
| int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); | |
| int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); | |
| int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); | |
| constexpr int num_spatial_dims = 2; | |
| DimensionHandle batch_size_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); | |
| DimensionHandle in_rows_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); | |
| DimensionHandle in_cols_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); | |
| DimensionHandle in_depth_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| ShapeHandle output_shape; | |
| DimensionHandle output_rows, output_cols, output_depth; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); | |
| TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, | |
| {output_rows, output_cols}, | |
| output_depth, &output_shape, c)); | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { | |
| string data_format_str; | |
| TensorFormat data_format; | |
| Status s = c->GetAttr("data_format", &data_format_str); | |
| if (s.ok()) { | |
| FormatFromString(data_format_str, &data_format); | |
| } else { | |
| data_format = FORMAT_NHWC; | |
| } | |
| const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); | |
| TF_RETURN_IF_ERROR( | |
| CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); | |
| std::vector<int32> kernel_sizes; | |
| std::vector<int32> strides; | |
| if (c->num_inputs() + 2 == num_inputs) { | |
| TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| } else { | |
| // Verify shape of ksize and strides input. | |
| ShapeHandle size; | |
| DimensionHandle unused; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); | |
| TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); | |
| TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); | |
| const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); | |
| if (kernel_sizes_tensor == nullptr) { | |
| c->set_output(0, c->UnknownShape()); | |
| return Status::OK(); | |
| } | |
| kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); | |
| auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); | |
| std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), | |
| kernel_sizes.begin()); | |
| const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); | |
| if (strides_tensor == nullptr) { | |
| c->set_output(0, c->UnknownShape()); | |
| return Status::OK(); | |
| } | |
| strides.resize(strides_tensor->shape().num_elements()); | |
| auto strides_vec = strides_tensor->flat<int32>(); | |
| std::copy_n(&strides_vec(0), strides.size(), strides.begin()); | |
| } | |
| if (strides.size() != 4) { | |
| return errors::InvalidArgument( | |
| "MaxPool requires the stride attribute to contain 4 values, but " | |
| "got: ", | |
| strides.size()); | |
| } | |
| if (kernel_sizes.size() != 4) { | |
| return errors::InvalidArgument( | |
| "MaxPool requires the ksize attribute to contain 4 values, but got: ", | |
| kernel_sizes.size()); | |
| } | |
| int32 stride_depth = GetTensorDim(strides, data_format, 'C'); | |
| int32 stride_rows = GetTensorDim(strides, data_format, 'H'); | |
| int32 stride_cols = GetTensorDim(strides, data_format, 'W'); | |
| int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); | |
| int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); | |
| int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); | |
| constexpr int num_spatial_dims = 2; | |
| DimensionHandle batch_size_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); | |
| DimensionHandle in_rows_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); | |
| DimensionHandle in_cols_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); | |
| DimensionHandle in_depth_dim = c->Dim( | |
| input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| ShapeHandle output_shape; | |
| DimensionHandle output_rows, output_cols, output_depth; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); | |
| TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, | |
| {output_rows, output_cols}, | |
| output_depth, &output_shape, c)); | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status Pool3DShape(shape_inference::InferenceContext* c) { | |
| ShapeHandle input_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); | |
| string data_format; | |
| Status s = c->GetAttr("data_format", &data_format); | |
| std::vector<int32> strides; | |
| TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); | |
| if (strides.size() != 5) { | |
| return errors::InvalidArgument( | |
| "Pool3D ops require the stride attribute to contain 5 values, but " | |
| "got: ", | |
| strides.size()); | |
| } | |
| std::vector<int32> kernel_sizes; | |
| TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); | |
| if (kernel_sizes.size() != 5) { | |
| return errors::InvalidArgument( | |
| "Pool3D requires the ksize attribute to contain 5 values, but got: ", | |
| kernel_sizes.size()); | |
| } | |
| int32 stride_planes, stride_rows, stride_cols; | |
| int32 kernel_planes, kernel_rows, kernel_cols; | |
| if (s.ok() && data_format == "NCDHW") { | |
| // Convert input_shape to NDHWC. | |
| auto dim = [&](char dimension) { | |
| return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); | |
| }; | |
| input_shape = | |
| c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); | |
| stride_planes = strides[2]; | |
| stride_rows = strides[3]; | |
| stride_cols = strides[4]; | |
| kernel_planes = kernel_sizes[2]; | |
| kernel_rows = kernel_sizes[3]; | |
| kernel_cols = kernel_sizes[4]; | |
| } else { | |
| stride_planes = strides[1]; | |
| stride_rows = strides[2]; | |
| stride_cols = strides[3]; | |
| kernel_planes = kernel_sizes[1]; | |
| kernel_rows = kernel_sizes[2]; | |
| kernel_cols = kernel_sizes[3]; | |
| } | |
| DimensionHandle batch_size_dim = c->Dim(input_shape, 0); | |
| DimensionHandle in_planes_dim = c->Dim(input_shape, 1); | |
| DimensionHandle in_rows_dim = c->Dim(input_shape, 2); | |
| DimensionHandle in_cols_dim = c->Dim(input_shape, 3); | |
| DimensionHandle output_depth_dim = c->Dim(input_shape, 4); | |
| Padding padding; | |
| TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); | |
| // TODO(mrry,shlens): Raise an error if the stride would cause | |
| // information in the input to be ignored. This will require a change | |
| // in the kernel implementation. | |
| DimensionHandle output_planes, output_rows, output_cols; | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); | |
| TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( | |
| c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); | |
| ShapeHandle output_shape; | |
| if (data_format == "NCDHW") { | |
| output_shape = c->MakeShape({batch_size_dim, output_depth_dim, | |
| output_planes, output_rows, output_cols}); | |
| } else { | |
| output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, | |
| output_cols, output_depth_dim}); | |
| } | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| Status UnknownShape(shape_inference::InferenceContext* c) { | |
| for (int i = 0; i < c->num_outputs(); ++i) { | |
| c->set_output(i, c->UnknownShape()); | |
| } | |
| return Status::OK(); | |
| } | |
| template <typename T> | |
| Status ReductionShapeHelper(const Tensor* reduction_indices_t, | |
| const int32 input_rank, | |
| std::set<int64>& true_indices) { | |
| auto reduction_indices = reduction_indices_t->flat<T>(); | |
| for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { | |
| const T reduction_index = reduction_indices(i); | |
| if (reduction_index < -input_rank || reduction_index >= input_rank) { | |
| return errors::InvalidArgument("Invalid reduction dimension ", | |
| reduction_index, " for input with ", | |
| input_rank, " dimensions."); | |
| } | |
| auto wrapped_index = reduction_index; | |
| if (wrapped_index < 0) { | |
| wrapped_index += input_rank; | |
| } | |
| true_indices.insert(wrapped_index); | |
| } | |
| return Status::OK(); | |
| } | |
| Status ReductionShape(InferenceContext* c) { | |
| ShapeHandle input = c->input(0); | |
| ShapeHandle indices; | |
| // Older versions of TensorFlow accidentally allowed higher rank tensors like | |
| // [[1,2]] or [[1],[2]] to represent axis=[1,2]. | |
| if (c->graph_def_version() < 21) { | |
| indices = c->input(1); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices)); | |
| } | |
| bool keep_dims; | |
| TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims)); | |
| const Tensor* reduction_indices_t = c->input_tensor(1); | |
| if (reduction_indices_t == nullptr || !c->RankKnown(input)) { | |
| // If we do not have the reduction values at runtime, or the | |
| // rank of the input, we don't know the output shape. | |
| if (keep_dims && c->RankKnown(input)) { | |
| // output rank matches input input if <keep_dims>. | |
| c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); | |
| return Status::OK(); | |
| } else { | |
| return shape_inference::UnknownShape(c); | |
| } | |
| } | |
| const int32 input_rank = c->Rank(input); | |
| std::set<int64> true_indices; | |
| if (reduction_indices_t->dtype() == DataType::DT_INT32) { | |
| TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t, | |
| input_rank, true_indices)); | |
| } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { | |
| TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t, | |
| input_rank, true_indices)); | |
| } else { | |
| return errors::InvalidArgument( | |
| "reduction_indices can only be int32 or int64"); | |
| } | |
| std::vector<DimensionHandle> dims; | |
| for (int i = 0; i < input_rank; ++i) { | |
| if (true_indices.count(i) > 0) { | |
| if (keep_dims) { | |
| dims.emplace_back(c->MakeDim(1)); | |
| } | |
| } else { | |
| dims.emplace_back(c->Dim(input, i)); | |
| } | |
| } | |
| c->set_output(0, c->MakeShape(dims)); | |
| return Status::OK(); | |
| } | |
| Status ConcatShapeHelper(InferenceContext* c, int start_value_index, | |
| int end_value_index, int dim_index) { | |
| ShapeHandle unused; | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused)); | |
| const Tensor* concat_dim_t = c->input_tensor(dim_index); | |
| if (concat_dim_t == nullptr) { | |
| // Return an unknown shape with same rank as inputs, or an unknown rank | |
| // if no input's rank is known. | |
| // Find rank. | |
| int32 rank = InferenceContext::kUnknownRank; | |
| for (int i = start_value_index; i < end_value_index; ++i) { | |
| if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i)); | |
| if (rank != InferenceContext::kUnknownRank) { | |
| break; | |
| } | |
| } | |
| if (rank == InferenceContext::kUnknownRank) { | |
| c->set_output(0, c->UnknownShape()); | |
| return Status::OK(); | |
| } else if (rank == 0) { | |
| return errors::InvalidArgument( | |
| "Can't concatenate scalars (use tf.stack instead)"); | |
| } else { | |
| for (int i = start_value_index; i < end_value_index; ++i) { | |
| // Check that all the inputs are of the correct rank. | |
| TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused)); | |
| } | |
| } | |
| // Build result of <rank> different unknown dims. | |
| std::vector<DimensionHandle> dims; | |
| dims.reserve(rank); | |
| for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); | |
| c->set_output(0, c->MakeShape(dims)); | |
| return Status::OK(); | |
| } | |
| // Merge all the non-concat dims, and sum the concat dim to make an output | |
| // shape. | |
| const int32 concat_dim = concat_dim_t->scalar<int32>()(); | |
| // Minimum required number of dimensions. | |
| const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; | |
| ShapeHandle output_before; | |
| ShapeHandle output_after; | |
| ShapeHandle input = c->input(end_value_index - 1); | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); | |
| TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before)); | |
| DimensionHandle output_middle = c->Dim(input, concat_dim); | |
| if (concat_dim == -1) { | |
| output_after = c->Scalar(); // no dimensions. | |
| } else { | |
| TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); | |
| } | |
| for (int i = end_value_index - 2; i >= start_value_index; --i) { | |
| ShapeHandle before; | |
| ShapeHandle after; | |
| input = c->input(i); | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); | |
| TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); | |
| DimensionHandle middle = c->Dim(input, concat_dim); | |
| if (concat_dim == -1) { | |
| after = c->Scalar(); | |
| } else { | |
| TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); | |
| } | |
| TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); | |
| TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); | |
| TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after)); | |
| } | |
| ShapeHandle s; | |
| TF_RETURN_IF_ERROR( | |
| c->Concatenate(output_before, c->Vector(output_middle), &s)); | |
| TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); | |
| c->set_output(0, s); | |
| return Status::OK(); | |
| } | |
| Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { | |
| return ConcatShapeHelper(c, 1 /* start_value_index */, | |
| 1 + num_inputs_to_concat /* end_value_index */, | |
| 0 /* dim_index */); | |
| } | |
| Status ConcatV2Shape(InferenceContext* c) { | |
| return ConcatShapeHelper(c, 0 /* start_value_index */, | |
| c->num_inputs() - 1 /* end_value_index */, | |
| c->num_inputs() - 1 /* dim_index */); | |
| } | |
| Status BroadcastBinaryOpShapeFn(InferenceContext* c) { | |
| ShapeHandle shape_x = c->input(0); | |
| ShapeHandle shape_y = c->input(1); | |
| if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { | |
| c->set_output(0, c->UnknownShape()); | |
| return Status::OK(); | |
| } | |
| const int32 rank_x = c->Rank(shape_x); | |
| const int32 rank_y = c->Rank(shape_y); | |
| const int32 rank_out = std::max(rank_x, rank_y); | |
| // To compute the broadcast dimensions, we zip together shape_x and shape_y | |
| // and | |
| // pad with 1 to make them the same length. | |
| std::vector<DimensionHandle> dims; | |
| DimensionHandle dim_one; | |
| if (rank_x != rank_y) dim_one = c->MakeDim(1); | |
| for (int i = 0; i < rank_out; ++i) { | |
| const auto dim_x = i < (rank_out - rank_x) | |
| ? dim_one | |
| : c->Dim(shape_x, i - (rank_out - rank_x)); | |
| const bool dim_y_is_one = (i < (rank_out - rank_y)); | |
| const auto dim_y = | |
| dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y)); | |
| if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) { | |
| // One or both dimensions is unknown. | |
| // | |
| // - If either dimension is greater than 1, we assume that the program is | |
| // correct, and the other dimension will be broadcast to match it. | |
| // TODO(cwhipkey): For shape inference, if we eliminate the shape checks | |
| // in C++ op code, we must still assert that the unknown dim is either 1 | |
| // or the same as the known dim. | |
| // - If either dimension is 1, the other dimension is the output. | |
| if (c->Value(dim_x) > 1) { | |
| dims.push_back(dim_x); | |
| } else if (c->Value(dim_y) > 1) { | |
| dims.push_back(dim_y); | |
| } else if (c->Value(dim_x) == 1) { | |
| dims.push_back(dim_y); | |
| } else if (c->Value(dim_y) == 1) { | |
| dims.push_back(dim_x); | |
| } else if (dim_y.SameHandle(dim_x)) { | |
| dims.push_back(dim_x); | |
| } else { | |
| dims.push_back(c->UnknownDim()); | |
| } | |
| } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) { | |
| if (c->Value(dim_x) == 1 && !dim_y_is_one) { | |
| // We will broadcast dim_x to dim_y. | |
| dims.push_back(dim_y); | |
| } else { | |
| DCHECK_EQ(c->Value(dim_y), 1); | |
| // We will broadcast dim_y to dim_x. | |
| dims.push_back(dim_x); | |
| } | |
| } else { | |
| DimensionHandle dim; | |
| TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim)); | |
| dims.push_back(dim); | |
| } | |
| } | |
| c->set_output(0, c->MakeShape(dims)); | |
| return Status::OK(); | |
| } | |
| Status RandomShape(shape_inference::InferenceContext* c) { | |
| shape_inference::ShapeHandle out; | |
| TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); | |
| c->set_output(0, out); | |
| return Status::OK(); | |
| } | |
| Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, | |
| ShapeHandle values_shape, ShapeHandle shape_shape) { | |
| // Validate ranks. | |
| ShapeHandle unused_shape; | |
| TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); | |
| TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); | |
| TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); | |
| // Number of elements in indices and values must match. | |
| DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); | |
| if (c->ValueKnown(num_index_elements_dim)) { | |
| DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); | |
| if (c->ValueKnown(num_values_elements_dim)) { | |
| int64 num_index_elements = c->Value(num_index_elements_dim); | |
| int64 num_values_elements = c->Value(num_values_elements_dim); | |
| if (num_index_elements != num_values_elements) { | |
| return errors::InvalidArgument("Number of elements in index (", | |
| num_index_elements, ") and values (", | |
| num_values_elements, ") do not match."); | |
| } | |
| } | |
| } | |
| // Rank embedded in indices must match shape. | |
| DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); | |
| if (c->ValueKnown(index_rank_dim)) { | |
| DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); | |
| if (c->ValueKnown(shape_rank_dim)) { | |
| int64 index_rank = c->Value(index_rank_dim); | |
| int32 shape_rank = c->Value(shape_rank_dim); | |
| if (index_rank != shape_rank) { | |
| return errors::InvalidArgument("Index rank (", index_rank, | |
| ") and shape rank (", shape_rank, | |
| ") do not match."); | |
| } | |
| } | |
| } | |
| return Status::OK(); | |
| } | |
| Status ScatterNdUpdateShape(InferenceContext* c) { | |
| ShapeHandle input_shape = c->input(0); | |
| if (c->input_handle_shapes_and_types(0) != nullptr) { | |
| input_shape = (*c->input_handle_shapes_and_types(0))[0].shape; | |
| } | |
| ShapeHandle indices_shape; | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); | |
| ShapeHandle updates_shape; | |
| TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); | |
| if (c->Value(c->NumElements(input_shape)) == 0 && | |
| (c->Value(c->NumElements(indices_shape)) > 0 || | |
| c->Value(c->NumElements(updates_shape)) > 0)) { | |
| return errors::InvalidArgument( | |
| "Indices and updates specified for empty output shape"); | |
| } | |
| if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { | |
| const int64 num_outer_dims = c->Rank(indices_shape) - 1; | |
| const DimensionHandle index_size = c->Dim(indices_shape, -1); | |
| // We can only do more validation if the last dimension of indices | |
| // is a known value. | |
| if (c->ValueKnown(index_size)) { | |
| const int64 ix = c->Value(index_size); | |
| ShapeHandle unused; | |
| ShapeHandle prefix_indices; | |
| TF_RETURN_IF_ERROR( | |
| c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); | |
| ShapeHandle prefix_updates; | |
| TF_RETURN_IF_ERROR( | |
| c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); | |
| Status s = c->Merge(prefix_indices, prefix_updates, &unused); | |
| if (!s.ok()) { | |
| return errors::InvalidArgument( | |
| "The outer ", num_outer_dims, " dimensions of indices.shape=", | |
| c->DebugString(indices_shape), " must match the outer ", | |
| num_outer_dims, " dimensions of updates.shape=", | |
| c->DebugString(updates_shape), ": ", s.error_message()); | |
| } | |
| ShapeHandle input_suffix; | |
| TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); | |
| ShapeHandle suffix_updates; | |
| TF_RETURN_IF_ERROR( | |
| c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); | |
| s = c->Merge(input_suffix, suffix_updates, &unused); | |
| if (!s.ok()) { | |
| return errors::InvalidArgument( | |
| "The inner ", c->Rank(input_shape) - ix, | |
| " dimensions of input.shape=", c->DebugString(input_shape), | |
| " must match the inner ", c->Rank(updates_shape) - num_outer_dims, | |
| " dimensions of updates.shape=", c->DebugString(updates_shape), | |
| ": ", s.error_message()); | |
| } | |
| } | |
| } | |
| if (c->input_handle_shapes_and_types(0) == nullptr) { | |
| c->set_output(0, input_shape); | |
| } | |
| return Status::OK(); | |
| } | |
| Status ExplicitShape(InferenceContext* c) { | |
| PartialTensorShape shape; | |
| TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); | |
| ShapeHandle output_shape; | |
| TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); | |
| c->set_output(0, output_shape); | |
| return Status::OK(); | |
| } | |
| } // namespace shape_inference | |
| } // namespace tensorflow | |