MaHaWo commited on
Commit
9737afc
·
1 Parent(s): 5a0d232

finish implementation of google_perch tflite stuff

Browse files
google_perch_tflite/model.py CHANGED
@@ -46,12 +46,6 @@ class Model(ModelBase):
46
  self.output_layer_index = output_details[1]["index"]
47
 
48
 
49
- def load_species_list(self):
50
- # TODO
51
- pass
52
-
53
-
54
-
55
  def predict(self, sample: np.array) -> np.array:
56
  """
57
  predict Make inference about the bird species for the preprocessed data passed to this function as arguments.
@@ -68,7 +62,7 @@ class Model(ModelBase):
68
  )
69
  self.model.allocate_tensors()
70
 
71
- # Make a prediction (Audio only for now)
72
  self.model.set_tensor(self.input_layer_index, data)
73
  self.model.invoke()
74
 
 
46
  self.output_layer_index = output_details[1]["index"]
47
 
48
 
 
 
 
 
 
 
49
  def predict(self, sample: np.array) -> np.array:
50
  """
51
  predict Make inference about the bird species for the preprocessed data passed to this function as arguments.
 
62
  )
63
  self.model.allocate_tensors()
64
 
65
+ # Make a prediction
66
  self.model.set_tensor(self.input_layer_index, data)
67
  self.model.invoke()
68
 
google_perch_tflite/preprocessor.py CHANGED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import iSparrow.preprocessor_base as ppb
3
+
4
+
5
+ class Preprocessor(ppb.PreprocessorBase):
6
+
7
+ def __init__(self,
8
+ sample_rate: int = 32000,
9
+ sample_secs: float = 5.0,
10
+ resample_type: str = "kaiser_fast",
11
+ **kwargs ):
12
+
13
+ super().__init__(
14
+ "google_perch_tflite",
15
+ sample_rate=sample_rate,
16
+ sample_secs=sample_secs,
17
+ resample_type=resample_type,
18
+ **kwargs
19
+ )
20
+
21
+ def process_audio_data(self, rawdata: np.array)->np.array:
22
+ self.chunks = []
23
+
24
+ # raise when sampling rate is unequal.
25
+ if self.actual_sampling_rate != self.sample_rate:
26
+ raise RuntimeError(
27
+ "Sampling rate is not the desired one. Desired sampling rate: {self.sample_rate}, actual sampling rate: {self.actual_sampling_rate}"
28
+ )
29
+
30
+ frame_length = int(self.sample_secs * self.sample_rate)
31
+ step_length = int(self.sample_secs - self.overlap) * self.sample_rate
32
+
33
+ self.chunks = tf_split_signal_into_chunks(
34
+ rawdata, frame_length, step_length, pad_end=True
35
+ ).numpy()
36
+
37
+ print(
38
+ "process audio data google: complete, read ",
39
+ str(len(self.chunks)),
40
+ "chunks.",
41
+ flush=True
42
+ )
43
+
44
+ return self.chunks
45
+
46
+
47
+ @classmethod
48
+ def from_cfg(cls, cfg: dict):
49
+
50
+ # make sure there are no more than the allowed keyword arguments in the cfg
51
+ allowed = [
52
+ "sample_rate",
53
+ "sample_secs",
54
+ "resample_type",
55
+ "duration",
56
+ "actual_sampling_rate",
57
+ ]
58
+
59
+ if len([key for key in cfg if key not in allowed]) > 0:
60
+ raise RuntimeError("Erroneous keyword arguments in preprocessor config")
61
+
62
+ return cls(**cfg)