// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_ #define CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_ #include #include #include "base/callback_forward.h" #include "base/component_export.h" #include "chromeos/services/machine_learning/public/cpp/service_connection.h" #include "chromeos/services/machine_learning/public/mojom/document_scanner.mojom.h" #include "chromeos/services/machine_learning/public/mojom/grammar_checker.mojom.h" #include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h" #include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h" #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h" #include "chromeos/services/machine_learning/public/mojom/model.mojom.h" #include "chromeos/services/machine_learning/public/mojom/soda.mojom.h" #include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h" #include "chromeos/services/machine_learning/public/mojom/text_classifier.mojom.h" #include "chromeos/services/machine_learning/public/mojom/text_suggester.mojom.h" #include "chromeos/services/machine_learning/public/mojom/web_platform_handwriting.mojom.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver_set.h" #include "mojo/public/cpp/bindings/remote_set.h" namespace chromeos { namespace machine_learning { // Fake implementation of chromeos::machine_learning::ServiceConnection. // Handles LoadModel (and Model::CreateGraphExecutor) by binding to itself. // Handles GraphExecutor::Execute by always returning the value specified by // a previous call to SetOutputValue. // Handles TextClassifier::Annotate by always returning the value specified by // a previous call to SetOutputAnnotation. // For use with ServiceConnection::UseFakeServiceConnectionForTesting(). class COMPONENT_EXPORT(CHROMEOS_MLSERVICE) FakeServiceConnectionImpl : public ServiceConnection, public mojom::MachineLearningService, public mojom::Model, public mojom::TextClassifier, public mojom::HandwritingRecognizer, public mojom::GrammarChecker, public mojom::GraphExecutor, public mojom::SodaRecognizer, public mojom::TextSuggester, public mojom::DocumentScanner, public web_platform::mojom::HandwritingRecognizer { public: FakeServiceConnectionImpl(); FakeServiceConnectionImpl(const FakeServiceConnectionImpl&) = delete; FakeServiceConnectionImpl& operator=(const FakeServiceConnectionImpl&) = delete; ~FakeServiceConnectionImpl() override; // ServiceConnection: mojom::MachineLearningService& GetMachineLearningService() override; void BindMachineLearningService( mojo::PendingReceiver receiver) override; void Initialize() override; // mojom::MachineLearningService: void Clone( mojo::PendingReceiver receiver) override; // It's safe to execute LoadBuiltinModel, LoadFlatBufferModel and // LoadTextClassifier for multi times, but all the receivers will be bound to // the same instance. void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec, mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadBuiltinModelCallback callback) override; void LoadFlatBufferModel( mojom::FlatBufferModelSpecPtr spec, mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadFlatBufferModelCallback callback) override; void LoadTextClassifier( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadTextClassifierCallback callback) override; void LoadHandwritingModel( mojom::HandwritingRecognizerSpecPtr spec, mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadHandwritingModelCallback result_callback) override; // Dedicated HWR API for Web Platform. void LoadWebPlatformHandwritingModel( web_platform::mojom::HandwritingModelConstraintPtr constraint, mojo::PendingReceiver receiver, LoadWebPlatformHandwritingModelCallback callback) override; void LoadGrammarChecker( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadGrammarCheckerCallback callback) override; void LoadSpeechRecognizer( mojom::SodaConfigPtr soda_config, mojo::PendingRemote soda_client, mojo::PendingReceiver soda_recognizer, mojom::MachineLearningService::LoadSpeechRecognizerCallback callback) override; void LoadTextSuggester( mojo::PendingReceiver receiver, mojom::TextSuggesterSpecPtr spec, mojom::MachineLearningService::LoadTextSuggesterCallback callback) override; void LoadDocumentScanner( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadDocumentScannerCallback callback) override; // mojom::Model: void REMOVED_0(mojo::PendingReceiver receiver, mojom::Model::REMOVED_0Callback callback) override; // mojom::Model: void REMOVED_4(mojom::HandwritingRecognizerSpecPtr spec, mojo::PendingReceiver receiver, mojom::MachineLearningService::REMOVED_4Callback result_callback) override; // mojom::Model: void CreateGraphExecutor( mojom::GraphExecutorOptionsPtr options, mojo::PendingReceiver receiver, mojom::Model::CreateGraphExecutorCallback callback) override; // mojom::GraphExecutor: // Execute() will return the tensor set by SetOutputValue() as the output. void Execute(base::flat_map inputs, const std::vector& output_names, mojom::GraphExecutor::ExecuteCallback callback) override; // Useful for simulating a failure at different stage. // There are different error codes at each stage, we just randomly pick one. void SetLoadModelFailure(); void SetCreateGraphExecutorFailure(); void SetExecuteFailure(); void SetLoadTextClassifierFailure(); // Reset all the Model related failures and make Execute succeed. void SetExecuteSuccess(); // Reset all the TextClassifier related failures and make LoadTextClassifier // succeed. // Currently, there are two interfaces related to TextClassifier // (|LoadTextClassifier|, |Annotate|) but only // |LoadTextClassifier| can fail. void SetTextClassifierSuccess(); // Call SetOutputValue() before Execute() to set the output tensor. void SetOutputValue(const std::vector& shape, const std::vector& value); // In async mode, FakeServiceConnectionImpl adds requests like // LoadBuiltinModel, CreateGraphExecutor to |pending_calls_| instead of // responding immediately. Calls in |pending_calls_| will run when // RunPendingCalls() is called. // It's useful when an unit test wants to test the async behaviour of real // ml-service. void SetAsyncMode(bool async_mode); void RunPendingCalls(); // Call SetOutputAnnotation() before Annotate() to set the output annotation. void SetOutputAnnotation( const std::vector& annotation); // Call SetOutputLanguages() before FindLanguages() to set the output // languages. void SetOutputLanguages(const std::vector& languages); // Call SetOutputGrammarCheckerResult() before Check() to set the output of // grammar checker. void SetOutputGrammarCheckerResult( const mojom::GrammarCheckerResultPtr& result); // Call SetOutputHandwritingRecognizerResult() before Recognize() to set the // output of handwriting. void SetOutputHandwritingRecognizerResult( const mojom::HandwritingRecognizerResultPtr& result); // Call SetOutputWebPlatformHandwritingRecognizerResult() before // GetPrediction() to set the output of handwriting. void SetOutputWebPlatformHandwritingRecognizerResult( const std::vector& predictions); // Call SetOutputTextSuggesterResult() before Suggest() to set the // output of a text suggestion query. void SetOutputTextSuggesterResult( const mojom::TextSuggesterResultPtr& result); // Call SetOutputDetectCornersResult() before // DetectCornersFrom{NV12/JPEG}Image() to set the output of corners detection. void SetOutputDetectCornersResult( const mojom::DetectCornersResultPtr& result); // Call SetOutputDoPostProcessingResult() before DoPostProcessing() to set the // output of document post processing. void SetOutputDoPostProcessingResult( const mojom::DoPostProcessingResultPtr& result); // mojom::TextClassifier: void Annotate(mojom::TextAnnotationRequestPtr request, mojom::TextClassifier::AnnotateCallback callback) override; // mojom::TextClassifier: void FindLanguages( const std::string& text, mojom::TextClassifier::FindLanguagesCallback callback) override; // mojom::TextClassifier: void REMOVED_1( mojom::REMOVED_TextSuggestSelectionRequestPtr request, mojom::TextClassifier::REMOVED_1Callback callback) override; // mojom::HandwritingRecognizer: void Recognize( mojom::HandwritingRecognitionQueryPtr query, mojom::HandwritingRecognizer::RecognizeCallback callback) override; // web_platform::mojom::HandwritingRecognizer void GetPrediction( std::vector strokes, web_platform::mojom::HandwritingHintsPtr hints, web_platform::mojom::HandwritingRecognizer::GetPredictionCallback callback) override; // mojom::GrammarChecker: void Check(mojom::GrammarCheckerQueryPtr query, mojom::GrammarChecker::CheckCallback callback) override; // mojom::SpeechRecognizer void AddAudio(const std::vector& audio) override; void Stop() override; void Start() override; void MarkDone() override; // mojom::TextSuggester: void Suggest(mojom::TextSuggesterQueryPtr query, mojom::TextSuggester::SuggestCallback callback) override; // mojom::DocumentScanner: void DetectCornersFromNV12Image( base::ReadOnlySharedMemoryRegion nv12_image, mojom::DocumentScanner::DetectCornersFromNV12ImageCallback callback) override; void DetectCornersFromJPEGImage( base::ReadOnlySharedMemoryRegion jpeg_image, mojom::DocumentScanner::DetectCornersFromJPEGImageCallback callback) override; void DoPostProcessing( base::ReadOnlySharedMemoryRegion jpeg_image, const std::vector& corners, chromeos::machine_learning::mojom::Rotation rotation, mojom::DocumentScanner::DoPostProcessingCallback callback) override; // Flush all relevant Mojo pipes. void FlushForTesting(); private: void ScheduleCall(base::OnceClosure call); void HandleLoadBuiltinModelCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadBuiltinModelCallback callback); void HandleLoadFlatBufferModelCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadFlatBufferModelCallback callback); void HandleCreateGraphExecutorCall( mojom::GraphExecutorOptionsPtr options, mojo::PendingReceiver receiver, mojom::Model::CreateGraphExecutorCallback callback); void HandleExecuteCall(mojom::GraphExecutor::ExecuteCallback callback); void HandleLoadTextClassifierCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadTextClassifierCallback callback); void HandleAnnotateCall(mojom::TextAnnotationRequestPtr request, mojom::TextClassifier::AnnotateCallback callback); void HandleFindLanguagesCall( std::string text, mojom::TextClassifier::FindLanguagesCallback callback); void HandleLoadHandwritingModelCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadHandwritingModelCallback callback); void HandleLoadWebPlatformHandwritingModelCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadHandwritingModelCallback callback); void HandleRecognizeCall( mojom::HandwritingRecognitionQueryPtr query, mojom::HandwritingRecognizer::RecognizeCallback callback); void HandleGetPredictionCall( std::vector strokes, web_platform::mojom::HandwritingHintsPtr hints, web_platform::mojom::HandwritingRecognizer::GetPredictionCallback callback); void HandleLoadGrammarCheckerCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadGrammarCheckerCallback callback); void HandleGrammarCheckerQueryCall( mojom::GrammarCheckerQueryPtr query, mojom::GrammarChecker::CheckCallback callback); void HandleLoadSpeechRecognizerCall( mojo::PendingRemote soda_client, mojo::PendingReceiver soda_recognizer, mojom::MachineLearningService::LoadSpeechRecognizerCallback callback); void HandleLoadTextSuggesterCall( mojo::PendingReceiver receiver, mojom::TextSuggesterSpecPtr spec, mojom::MachineLearningService::LoadTextSuggesterCallback callback); void HandleTextSuggesterSuggestCall( mojom::TextSuggesterQueryPtr query, mojom::TextSuggester::SuggestCallback callback); void HandleLoadDocumentScannerCall( mojo::PendingReceiver receiver, mojom::MachineLearningService::LoadDocumentScannerCallback callback); void HandleDocumentScannerDetectNV12Call( base::ReadOnlySharedMemoryRegion nv12_image, mojom::DocumentScanner::DetectCornersFromNV12ImageCallback callback); void HandleDocumentScannerDetectJPEGCall( base::ReadOnlySharedMemoryRegion jpeg_image, mojom::DocumentScanner::DetectCornersFromJPEGImageCallback callback); void HandleDocumentScannerPostProcessingCall( base::ReadOnlySharedMemoryRegion jpeg_image, const std::vector& corners, mojom::DocumentScanner::DoPostProcessingCallback callback); void HandleStopCall(); void HandleStartCall(); void HandleMarkDoneCall(); // Additional receivers bound via `Clone`. mojo::ReceiverSet clone_ml_service_receivers_; mojo::Remote machine_learning_service_; mojo::ReceiverSet model_receivers_; mojo::ReceiverSet graph_receivers_; mojo::ReceiverSet text_classifier_receivers_; mojo::ReceiverSet handwriting_receivers_; mojo::ReceiverSet web_platform_handwriting_receivers_; mojo::ReceiverSet grammar_checker_receivers_; mojo::ReceiverSet soda_recognizer_receivers_; mojo::ReceiverSet text_suggester_receivers_; mojo::ReceiverSet document_scanner_receivers_; mojo::RemoteSet soda_client_remotes_; mojom::TensorPtr output_tensor_; mojom::LoadHandwritingModelResult load_handwriting_model_result_; mojom::LoadHandwritingModelResult load_web_platform_handwriting_model_result_; mojom::LoadModelResult load_model_result_; mojom::LoadModelResult load_text_classifier_result_; mojom::LoadModelResult load_soda_result_; mojom::CreateGraphExecutorResult create_graph_executor_result_; mojom::ExecuteResult execute_result_; std::vector annotate_result_; mojom::CodepointSpanPtr suggest_selection_result_; std::vector find_languages_result_; mojom::HandwritingRecognizerResultPtr handwriting_result_; std::vector web_platform_handwriting_result_; mojom::GrammarCheckerResultPtr grammar_checker_result_; mojom::TextSuggesterResultPtr text_suggester_result_; mojom::DetectCornersResultPtr detect_corners_result_; mojom::DoPostProcessingResultPtr do_post_processing_result_; bool async_mode_; std::vector pending_calls_; }; } // namespace machine_learning } // namespace chromeos #endif // CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_