fomo-face-detection
/
ei-cpp-export
/edge-impulse-sdk
/tensorflow
/lite
/micro
/kernels
/logistic.cc
| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| ==============================================================================*/ | |
| namespace tflite { | |
| namespace ops { | |
| namespace micro { | |
| namespace activations { | |
| namespace { | |
// Tensor index handed to GetInput()/GetEvalInput() to fetch this op's input.
constexpr int kInputTensor{0};
// Tensor index handed to GetOutput()/GetEvalOutput() to fetch this op's output.
constexpr int kOutputTensor{0};
// Per-node quantization parameters for the int8 logistic kernel.
//
// The storage is obtained in LogisticInit(), populated in
// CalculateArithmeticOpData() (int8 inputs only) and forwarded verbatim to
// reference_integer_ops::Logistic() in LogisticEval().
struct OpData {
  // Copy of the input tensor's quantization zero point.
  int32_t input_zero_point;
  // Value returned by CalculateInputRadius() for the computed left shift.
  int32_t input_range_radius;
  // Mantissa of the real input multiplier, scaled by 2^31 and rounded.
  int32_t input_multiplier;
  // Binary exponent produced by std::frexp() for the real input multiplier.
  int input_left_shift;
};
| TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node, | |
| OpData* data) { | |
| const TfLiteTensor* input = GetInput(context, node, kInputTensor); | |
| TF_LITE_ENSURE(context, input != nullptr); | |
| TfLiteTensor* output = GetOutput(context, node, kOutputTensor); | |
| TF_LITE_ENSURE(context, output != nullptr); | |
| TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); | |
| if (input->type == kTfLiteInt8) { | |
| TF_LITE_ENSURE_EQ(context, output->params.zero_point, | |
| std::numeric_limits<int8_t>::min()); | |
| static constexpr int kInputIntegerBits = 4; | |
| const double input_real_multiplier = | |
| static_cast<double>(input->params.scale) * | |
| static_cast<double>(1 << (31 - kInputIntegerBits)); | |
| data->input_zero_point = input->params.zero_point; | |
| const double q = std::frexp(input_real_multiplier, &data->input_left_shift); | |
| data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31))); | |
| data->input_range_radius = | |
| CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31); | |
| } | |
| return kTfLiteOk; | |
| } | |
| } // namespace | |
| void* LogisticInit(TfLiteContext* context, const char* buffer, size_t length) { | |
| TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); | |
| return context->AllocatePersistentBuffer(context, sizeof(OpData)); | |
| } | |
| TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node) { | |
| TFLITE_DCHECK(node->user_data != nullptr); | |
| OpData* data = static_cast<OpData*>(node->user_data); | |
| return CalculateArithmeticOpData(context, node, data); | |
| } | |
| TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) { | |
| const TfLiteEvalTensor* input = | |
| tflite::micro::GetEvalInput(context, node, kInputTensor); | |
| TfLiteEvalTensor* output = | |
| tflite::micro::GetEvalOutput(context, node, kOutputTensor); | |
| TFLITE_DCHECK(node->user_data != nullptr); | |
| OpData* data = static_cast<OpData*>(node->user_data); | |
| if (input->type == kTfLiteFloat32) { | |
| switch (output->type) { | |
| case kTfLiteFloat32: { | |
| reference_ops::Logistic(tflite::micro::GetTensorShape(input), | |
| tflite::micro::GetTensorData<float>(input), | |
| tflite::micro::GetTensorShape(output), | |
| tflite::micro::GetTensorData<float>(output)); | |
| return kTfLiteOk; | |
| } | |
| default: | |
| TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", | |
| TfLiteTypeGetName(input->type), | |
| TfLiteTypeGetName(output->type)); | |
| return kTfLiteError; | |
| } | |
| } else if (input->type == kTfLiteInt8) { | |
| switch (output->type) { | |
| case kTfLiteInt8: { | |
| reference_integer_ops::Logistic( | |
| data->input_zero_point, data->input_range_radius, | |
| data->input_multiplier, data->input_left_shift, | |
| NumElements(input->dims), | |
| tflite::micro::GetTensorData<int8_t>(input), | |
| tflite::micro::GetTensorData<int8_t>(output)); | |
| return kTfLiteOk; | |
| } | |
| default: | |
| TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", | |
| TfLiteTypeGetName(input->type), | |
| TfLiteTypeGetName(output->type)); | |
| return kTfLiteError; | |
| } | |
| } else { | |
| // TODO(b/141211002): Also support other data types once we have supported | |
| // temporary tensors in TFLM. | |
| TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", | |
| TfLiteTypeGetName(input->type), | |
| TfLiteTypeGetName(output->type)); | |
| return kTfLiteError; | |
| } | |
| return kTfLiteOk; | |
| } | |
| } // namespace activations | |
| TfLiteRegistration Register_LOGISTIC() { | |
| return {/*init=*/activations::LogisticInit, | |
| /*free=*/nullptr, | |
| /*prepare=*/activations::LogisticPrepare, | |
| /*invoke=*/activations::LogisticEval, | |
| /*profiling_string=*/nullptr, | |
| /*builtin_code=*/0, | |
| /*custom_name=*/nullptr, | |
| /*version=*/0}; | |
| } | |
| } // namespace micro | |
| } // namespace ops | |
| } // namespace tflite | |