File size: 3,026 Bytes
5a0d232
97f3fc0
 
5a0d232
 
 
 
97f3fc0
 
5a0d232
8729d1a
b52f3c1
cf40c2c
5a0d232
97f3fc0
 
5a0d232
 
 
 
 
 
 
97f3fc0
5a0d232
 
 
 
 
 
 
 
 
 
 
 
3dc55b1
5a0d232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9737afc
5a0d232
 
 
 
 
cf40c2c
5a0d232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from iSparrow.sparrow_model_base import ModelBase

try:
    import tflite_runtime.interpreter as tflite
except ImportError:
    import tensorflow.lite as tflite

from iSparrow import utils
from iSparrow import ModelBase

import numpy as np
from pathlib import Path
from scipy.special import softmax


class Model(ModelBase):
    """
    Model Implementation of a iSparrow model that uses the google perch tflite model.

    Args:
        ModelBase (iSparrow.ModelBase): Model base class that provides the interface through which to interact with iSparrow.
    """

    def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
        """
        __init__ Create a new model instance that uses the google perch tflite converted model.

        Args:
            model_path (str): path to the directory that contains both
                'model.tflite' and 'labels.txt' for the google perch model
            num_threads (int, optional): number of threads to use. Defaults to 1.
        """
        labels_path = str(Path(model_path) / "labels.txt")
        model_path = str(Path(model_path) / "model.tflite")

        # base class loads the model and labels
        super().__init__(
            "google_perch_lite",
            model_path,
            labels_path,
            num_threads=num_threads,
            **kwargs
        )

        # store input and output index to not have to retrieve them each time
        # an inference is made
        input_details = self.model.get_input_details()
        output_details = self.model.get_output_details()

        self.input_layer_index = input_details[0]["index"]

        # NOTE(review): output index 1 (not 0) is read here — presumably the
        # perch model exposes multiple output tensors (e.g. embeddings at 0 and
        # class logits at 1). Confirm against the model's output signature.
        self.output_layer_index = output_details[1]["index"]

    def predict(self, sample: np.ndarray) -> np.ndarray:
        """
        predict Make inference about the bird species for the preprocessed data passed to this function as arguments.

        Args:
            sample (np.ndarray): a single preprocessed data chunk

        Returns:
            np.ndarray: array of probabilities per class (softmax over the
            model's output logits)
        """
        # the model expects a leading batch dimension; wrap the single chunk
        data = np.array([sample], dtype="float32")

        # resize the input tensor to the batch shape before allocating
        self.model.resize_tensor_input(
            self.input_layer_index, [len(data), *data[0].shape]
        )
        self.model.allocate_tensors()

        # Make a prediction
        self.model.set_tensor(self.input_layer_index, data)
        self.model.invoke()

        logits = self.model.get_tensor(self.output_layer_index)

        # softmax over the class axis. The previous call used scipy's default
        # axis=None, which normalizes over the flattened array — identical for
        # the batch-of-one produced above, but wrong for any larger batch, so
        # the class axis is pinned explicitly.
        confidence = softmax(logits, axis=-1)

        return confidence

    @classmethod
    def from_cfg(cls, sparrow_folder: str, cfg: dict):
        """
        from_cfg Create a new instance from a dictionary containing keyword arguments. Usually loaded from a config file.

        Args:
            sparrow_folder (str): Installation directory of the Sparrow package
            cfg (dict): Dictionary containing the keyword arguments

        Returns:
            Model: New model instance created with the supplied kwargs.
        """
        # copy so the caller's dict is not mutated as a side effect
        cfg = dict(cfg)

        # NOTE(review): this rewrites 'model_name' to an absolute path, yet
        # __init__ takes 'model_path' — the key presumably flows through
        # **kwargs into the base class; verify cfg actually carries the keys
        # __init__ expects.
        cfg["model_name"] = str(
            Path(sparrow_folder) / Path("models") / cfg["model_name"]
        )

        return cls(**cfg)