File size: 17,152 Bytes
f8dc3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_
#define CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_

#include <memory>
#include <vector>

#include "base/callback_forward.h"
#include "base/component_export.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/document_scanner.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/grammar_checker.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/soda.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/text_classifier.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/text_suggester.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/web_platform_handwriting.mojom.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "mojo/public/cpp/bindings/remote_set.h"

namespace chromeos {
namespace machine_learning {

// Fake implementation of chromeos::machine_learning::ServiceConnection.
//
// A single instance implements every ML-service interface listed below, and
// receivers handed to the Load*() / CreateGraphExecutor() methods are bound to
// this same object, so a test can script all responses in one place:
//  - LoadBuiltinModel / LoadFlatBufferModel (and Model::CreateGraphExecutor)
//    bind to this object; GraphExecutor::Execute returns the tensor specified
//    by a previous call to SetOutputValue().
//  - TextClassifier::Annotate returns the value specified by a previous call to
//    SetOutputAnnotation(); TextClassifier::FindLanguages returns the value
//    specified by SetOutputLanguages().
//  - HandwritingRecognizer::Recognize, web_platform::HandwritingRecognizer::
//    GetPrediction, GrammarChecker::Check, TextSuggester::Suggest and the
//    DocumentScanner methods return the values set by the matching
//    SetOutput*() methods.
// Failures can be injected with the Set*Failure() methods, and responses can be
// deferred with SetAsyncMode() / RunPendingCalls().
// For use with ServiceConnection::UseFakeServiceConnectionForTesting().
class COMPONENT_EXPORT(CHROMEOS_MLSERVICE) FakeServiceConnectionImpl
    : public ServiceConnection,
      public mojom::MachineLearningService,
      public mojom::Model,
      public mojom::TextClassifier,
      public mojom::HandwritingRecognizer,
      public mojom::GrammarChecker,
      public mojom::GraphExecutor,
      public mojom::SodaRecognizer,
      public mojom::TextSuggester,
      public mojom::DocumentScanner,
      public web_platform::mojom::HandwritingRecognizer {
 public:
  FakeServiceConnectionImpl();

  FakeServiceConnectionImpl(const FakeServiceConnectionImpl&) = delete;
  FakeServiceConnectionImpl& operator=(const FakeServiceConnectionImpl&) =
      delete;

  ~FakeServiceConnectionImpl() override;

  // ServiceConnection:
  mojom::MachineLearningService& GetMachineLearningService() override;
  void BindMachineLearningService(
      mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
  void Initialize() override;

  // mojom::MachineLearningService:
  void Clone(
      mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;

  // It's safe to call LoadBuiltinModel, LoadFlatBufferModel and
  // LoadTextClassifier multiple times, but all the receivers will be bound to
  // the same instance.
  void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,
                        mojo::PendingReceiver<mojom::Model> receiver,
                        mojom::MachineLearningService::LoadBuiltinModelCallback
                            callback) override;
  void LoadFlatBufferModel(
      mojom::FlatBufferModelSpecPtr spec,
      mojo::PendingReceiver<mojom::Model> receiver,
      mojom::MachineLearningService::LoadFlatBufferModelCallback callback)
      override;

  void LoadTextClassifier(
      mojo::PendingReceiver<mojom::TextClassifier> receiver,
      mojom::MachineLearningService::LoadTextClassifierCallback callback)
      override;

  void LoadHandwritingModel(
      mojom::HandwritingRecognizerSpecPtr spec,
      mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
      mojom::MachineLearningService::LoadHandwritingModelCallback
          result_callback) override;

  // Dedicated HWR API for Web Platform.
  void LoadWebPlatformHandwritingModel(
      web_platform::mojom::HandwritingModelConstraintPtr constraint,
      mojo::PendingReceiver<web_platform::mojom::HandwritingRecognizer>
          receiver,
      LoadWebPlatformHandwritingModelCallback callback) override;

  void LoadGrammarChecker(
      mojo::PendingReceiver<mojom::GrammarChecker> receiver,
      mojom::MachineLearningService::LoadGrammarCheckerCallback callback)
      override;

  void LoadSpeechRecognizer(
      mojom::SodaConfigPtr soda_config,
      mojo::PendingRemote<mojom::SodaClient> soda_client,
      mojo::PendingReceiver<mojom::SodaRecognizer> soda_recognizer,
      mojom::MachineLearningService::LoadSpeechRecognizerCallback callback)
      override;

  void LoadTextSuggester(
      mojo::PendingReceiver<mojom::TextSuggester> receiver,
      mojom::TextSuggesterSpecPtr spec,
      mojom::MachineLearningService::LoadTextSuggesterCallback callback)
      override;

  void LoadDocumentScanner(
      mojo::PendingReceiver<mojom::DocumentScanner> receiver,
      mojom::MachineLearningService::LoadDocumentScannerCallback callback)
      override;

  // mojom::Model:
  void REMOVED_0(mojo::PendingReceiver<mojom::GraphExecutor> receiver,
                 mojom::Model::REMOVED_0Callback callback) override;

  // mojom::MachineLearningService:
  // (Its callback type is mojom::MachineLearningService::REMOVED_4Callback,
  // so this is a MachineLearningService method, not a Model method.)
  void REMOVED_4(mojom::HandwritingRecognizerSpecPtr spec,
                 mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
                 mojom::MachineLearningService::REMOVED_4Callback
                     result_callback) override;

  // mojom::Model:
  void CreateGraphExecutor(
      mojom::GraphExecutorOptionsPtr options,
      mojo::PendingReceiver<mojom::GraphExecutor> receiver,
      mojom::Model::CreateGraphExecutorCallback callback) override;

  // mojom::GraphExecutor:
  // Execute() will return the tensor set by SetOutputValue() as the output.
  void Execute(base::flat_map<std::string, mojom::TensorPtr> inputs,
               const std::vector<std::string>& output_names,
               mojom::GraphExecutor::ExecuteCallback callback) override;

  // Useful for simulating a failure at different stages.
  // Each stage has several possible error codes; the fake reports just one of
  // them.
  void SetLoadModelFailure();
  void SetCreateGraphExecutorFailure();
  void SetExecuteFailure();
  void SetLoadTextClassifierFailure();
  // Reset all the Model related failures and make Execute succeed.
  void SetExecuteSuccess();
  // Reset all the TextClassifier related failures and make LoadTextClassifier
  // succeed.
  // Of the TextClassifier-related methods (|LoadTextClassifier|, |Annotate|,
  // |FindLanguages|), only |LoadTextClassifier| can be made to fail.
  void SetTextClassifierSuccess();

  // Call SetOutputValue() before Execute() to set the output tensor.
  void SetOutputValue(const std::vector<int64_t>& shape,
                      const std::vector<double>& value);

  // In async mode, FakeServiceConnectionImpl adds requests like
  // LoadBuiltinModel, CreateGraphExecutor to |pending_calls_| instead of
  // responding immediately. Calls in |pending_calls_| will run when
  // RunPendingCalls() is called.
  // This is useful when a unit test wants to exercise the asynchronous
  // behaviour of the real ml-service.
  void SetAsyncMode(bool async_mode);
  void RunPendingCalls();

  // Call SetOutputAnnotation() before Annotate() to set the output annotation.
  void SetOutputAnnotation(
      const std::vector<mojom::TextAnnotationPtr>& annotation);

  // Call SetOutputLanguages() before FindLanguages() to set the output
  // languages.
  void SetOutputLanguages(const std::vector<mojom::TextLanguagePtr>& languages);

  // Call SetOutputGrammarCheckerResult() before Check() to set the output of
  // grammar checker.
  void SetOutputGrammarCheckerResult(
      const mojom::GrammarCheckerResultPtr& result);

  // Call SetOutputHandwritingRecognizerResult() before Recognize() to set the
  // output of handwriting.
  void SetOutputHandwritingRecognizerResult(
      const mojom::HandwritingRecognizerResultPtr& result);

  // Call SetOutputWebPlatformHandwritingRecognizerResult() before
  // GetPrediction() to set the output of handwriting.
  void SetOutputWebPlatformHandwritingRecognizerResult(
      const std::vector<web_platform::mojom::HandwritingPredictionPtr>&
          predictions);

  // Call SetOutputTextSuggesterResult() before Suggest() to set the
  // output of a text suggestion query.
  void SetOutputTextSuggesterResult(
      const mojom::TextSuggesterResultPtr& result);

  // Call SetOutputDetectCornersResult() before
  // DetectCornersFrom{NV12/JPEG}Image() to set the output of corners detection.
  void SetOutputDetectCornersResult(
      const mojom::DetectCornersResultPtr& result);

  // Call SetOutputDoPostProcessingResult() before DoPostProcessing() to set the
  // output of document post processing.
  void SetOutputDoPostProcessingResult(
      const mojom::DoPostProcessingResultPtr& result);

  // mojom::TextClassifier:
  void Annotate(mojom::TextAnnotationRequestPtr request,
                mojom::TextClassifier::AnnotateCallback callback) override;

  // mojom::TextClassifier:
  void FindLanguages(
      const std::string& text,
      mojom::TextClassifier::FindLanguagesCallback callback) override;

  // mojom::TextClassifier:
  void REMOVED_1(
      mojom::REMOVED_TextSuggestSelectionRequestPtr request,
      mojom::TextClassifier::REMOVED_1Callback callback) override;

  // mojom::HandwritingRecognizer:
  void Recognize(
      mojom::HandwritingRecognitionQueryPtr query,
      mojom::HandwritingRecognizer::RecognizeCallback callback) override;

  // web_platform::mojom::HandwritingRecognizer:
  void GetPrediction(
      std::vector<web_platform::mojom::HandwritingStrokePtr> strokes,
      web_platform::mojom::HandwritingHintsPtr hints,
      web_platform::mojom::HandwritingRecognizer::GetPredictionCallback
          callback) override;

  // mojom::GrammarChecker:
  void Check(mojom::GrammarCheckerQueryPtr query,
             mojom::GrammarChecker::CheckCallback callback) override;

  // mojom::SodaRecognizer:
  void AddAudio(const std::vector<uint8_t>& audio) override;
  void Stop() override;
  void Start() override;
  void MarkDone() override;

  // mojom::TextSuggester:
  void Suggest(mojom::TextSuggesterQueryPtr query,
               mojom::TextSuggester::SuggestCallback callback) override;

  // mojom::DocumentScanner:
  void DetectCornersFromNV12Image(
      base::ReadOnlySharedMemoryRegion nv12_image,
      mojom::DocumentScanner::DetectCornersFromNV12ImageCallback callback)
      override;
  void DetectCornersFromJPEGImage(
      base::ReadOnlySharedMemoryRegion jpeg_image,
      mojom::DocumentScanner::DetectCornersFromJPEGImageCallback callback)
      override;
  void DoPostProcessing(
      base::ReadOnlySharedMemoryRegion jpeg_image,
      const std::vector<gfx::PointF>& corners,
      chromeos::machine_learning::mojom::Rotation rotation,
      mojom::DocumentScanner::DoPostProcessingCallback callback) override;

  // Flush all relevant Mojo pipes.
  void FlushForTesting();

 private:
  // Runs |call| now, or queues it in |pending_calls_| when async mode is on
  // (see SetAsyncMode()).
  void ScheduleCall(base::OnceClosure call);

  // Handle*Call() methods hold the actual response logic for the public
  // interface methods above.
  void HandleLoadBuiltinModelCall(
      mojo::PendingReceiver<mojom::Model> receiver,
      mojom::MachineLearningService::LoadBuiltinModelCallback callback);
  void HandleLoadFlatBufferModelCall(
      mojo::PendingReceiver<mojom::Model> receiver,
      mojom::MachineLearningService::LoadFlatBufferModelCallback callback);
  void HandleCreateGraphExecutorCall(
      mojom::GraphExecutorOptionsPtr options,
      mojo::PendingReceiver<mojom::GraphExecutor> receiver,
      mojom::Model::CreateGraphExecutorCallback callback);
  void HandleExecuteCall(mojom::GraphExecutor::ExecuteCallback callback);
  void HandleLoadTextClassifierCall(
      mojo::PendingReceiver<mojom::TextClassifier> receiver,
      mojom::MachineLearningService::LoadTextClassifierCallback callback);
  void HandleAnnotateCall(mojom::TextAnnotationRequestPtr request,
                          mojom::TextClassifier::AnnotateCallback callback);
  void HandleFindLanguagesCall(
      std::string text,
      mojom::TextClassifier::FindLanguagesCallback callback);
  void HandleLoadHandwritingModelCall(
      mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
      mojom::MachineLearningService::LoadHandwritingModelCallback callback);
  // NOTE(review): the public LoadWebPlatformHandwritingModel() receives a
  // LoadWebPlatformHandwritingModelCallback, but this declaration spells the
  // parameter as LoadHandwritingModelCallback. Presumably both aliases name
  // the same callback type; confirm and consider using the matching name.
  void HandleLoadWebPlatformHandwritingModelCall(
      mojo::PendingReceiver<web_platform::mojom::HandwritingRecognizer>
          receiver,
      mojom::MachineLearningService::LoadHandwritingModelCallback callback);
  void HandleRecognizeCall(
      mojom::HandwritingRecognitionQueryPtr query,
      mojom::HandwritingRecognizer::RecognizeCallback callback);
  void HandleGetPredictionCall(
      std::vector<web_platform::mojom::HandwritingStrokePtr> strokes,
      web_platform::mojom::HandwritingHintsPtr hints,
      web_platform::mojom::HandwritingRecognizer::GetPredictionCallback
          callback);
  void HandleLoadGrammarCheckerCall(
      mojo::PendingReceiver<mojom::GrammarChecker> receiver,
      mojom::MachineLearningService::LoadGrammarCheckerCallback callback);
  void HandleGrammarCheckerQueryCall(
      mojom::GrammarCheckerQueryPtr query,
      mojom::GrammarChecker::CheckCallback callback);
  void HandleLoadSpeechRecognizerCall(
      mojo::PendingRemote<mojom::SodaClient> soda_client,
      mojo::PendingReceiver<mojom::SodaRecognizer> soda_recognizer,
      mojom::MachineLearningService::LoadSpeechRecognizerCallback callback);
  void HandleLoadTextSuggesterCall(
      mojo::PendingReceiver<mojom::TextSuggester> receiver,
      mojom::TextSuggesterSpecPtr spec,
      mojom::MachineLearningService::LoadTextSuggesterCallback callback);
  void HandleTextSuggesterSuggestCall(
      mojom::TextSuggesterQueryPtr query,
      mojom::TextSuggester::SuggestCallback callback);
  void HandleLoadDocumentScannerCall(
      mojo::PendingReceiver<mojom::DocumentScanner> receiver,
      mojom::MachineLearningService::LoadDocumentScannerCallback callback);
  void HandleDocumentScannerDetectNV12Call(
      base::ReadOnlySharedMemoryRegion nv12_image,
      mojom::DocumentScanner::DetectCornersFromNV12ImageCallback callback);
  void HandleDocumentScannerDetectJPEGCall(
      base::ReadOnlySharedMemoryRegion jpeg_image,
      mojom::DocumentScanner::DetectCornersFromJPEGImageCallback callback);
  void HandleDocumentScannerPostProcessingCall(
      base::ReadOnlySharedMemoryRegion jpeg_image,
      const std::vector<gfx::PointF>& corners,
      mojom::DocumentScanner::DoPostProcessingCallback callback);

  void HandleStopCall();
  void HandleStartCall();
  void HandleMarkDoneCall();

  // Additional receivers bound via `Clone`.
  mojo::ReceiverSet<mojom::MachineLearningService> clone_ml_service_receivers_;

  // Connection returned by GetMachineLearningService().
  mojo::Remote<mojom::MachineLearningService> machine_learning_service_;

  // One receiver set per implemented interface; every receiver is bound to
  // this object.
  mojo::ReceiverSet<mojom::Model> model_receivers_;
  mojo::ReceiverSet<mojom::GraphExecutor> graph_receivers_;
  mojo::ReceiverSet<mojom::TextClassifier> text_classifier_receivers_;
  mojo::ReceiverSet<mojom::HandwritingRecognizer> handwriting_receivers_;
  mojo::ReceiverSet<web_platform::mojom::HandwritingRecognizer>
      web_platform_handwriting_receivers_;
  mojo::ReceiverSet<mojom::GrammarChecker> grammar_checker_receivers_;
  mojo::ReceiverSet<mojom::SodaRecognizer> soda_recognizer_receivers_;
  mojo::ReceiverSet<mojom::TextSuggester> text_suggester_receivers_;
  mojo::ReceiverSet<mojom::DocumentScanner> document_scanner_receivers_;
  mojo::RemoteSet<mojom::SodaClient> soda_client_remotes_;

  // Canned results returned by the fake; configured via the Set*() methods.
  mojom::TensorPtr output_tensor_;
  mojom::LoadHandwritingModelResult load_handwriting_model_result_;
  mojom::LoadHandwritingModelResult load_web_platform_handwriting_model_result_;
  mojom::LoadModelResult load_model_result_;
  mojom::LoadModelResult load_text_classifier_result_;
  mojom::LoadModelResult load_soda_result_;
  mojom::CreateGraphExecutorResult create_graph_executor_result_;
  mojom::ExecuteResult execute_result_;
  std::vector<mojom::TextAnnotationPtr> annotate_result_;
  mojom::CodepointSpanPtr suggest_selection_result_;
  std::vector<mojom::TextLanguagePtr> find_languages_result_;
  mojom::HandwritingRecognizerResultPtr handwriting_result_;
  std::vector<web_platform::mojom::HandwritingPredictionPtr>
      web_platform_handwriting_result_;
  mojom::GrammarCheckerResultPtr grammar_checker_result_;
  mojom::TextSuggesterResultPtr text_suggester_result_;
  mojom::DetectCornersResultPtr detect_corners_result_;
  mojom::DoPostProcessingResultPtr do_post_processing_result_;

  // Whether calls are queued in |pending_calls_| instead of run immediately.
  // Defaults to false so the flag is never read uninitialized, even if a
  // constructor forgets to set it.
  bool async_mode_ = false;
  // Calls deferred while in async mode; drained by RunPendingCalls().
  std::vector<base::OnceClosure> pending_calls_;
};

}  // namespace machine_learning
}  // namespace chromeos

#endif  // CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_