|
|
from iSparrow.sparrow_model_base import ModelBase |
|
|
|
|
|
try: |
|
|
import tflite_runtime.interpreter as tflite |
|
|
except ImportError: |
|
|
import tensorflow.lite as tflite |
|
|
|
|
|
from iSparrow import utils |
|
|
from iSparrow import ModelBase |
|
|
|
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from scipy.special import softmax |
|
|
|
|
|
|
|
|
class Model(ModelBase):
    """
    Model Implementation of an iSparrow model that uses the google perch tflite model.

    Args:
        ModelBase (iSparrow.ModelBase): Model base class that provides the interface through which to interact with iSparrow.
    """

    def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
        """
        __init__ Create a new model instance that uses the google perch tflite converted model.

        Args:
            model_path (str): Directory that contains the files 'model.tflite' and 'labels.txt'.
            num_threads (int, optional): Number of threads to use. Defaults to 1.
            **kwargs: Further keyword arguments, forwarded unchanged to the ModelBase constructor.
        """
        model_dir = Path(model_path)
        labels_path = str(model_dir / "labels.txt")
        model_path = str(model_dir / "model.tflite")

        super().__init__(
            "google_perch_lite",
            model_path,
            labels_path,
            num_threads=num_threads,
            **kwargs,
        )

        # NOTE(review): self.model is not created in this class; it is presumably
        # the tflite interpreter set up by ModelBase.__init__ - confirm there.
        input_details = self.model.get_input_details()
        output_details = self.model.get_output_details()

        self.input_layer_index = input_details[0]["index"]

        # The second output tensor is the one predict() treats as logits.
        # NOTE(review): the meaning of output 0 is not visible here - verify
        # against the converted model before changing this index.
        self.output_layer_index = output_details[1]["index"]

    def predict(self, sample: np.ndarray) -> np.ndarray:
        """
        predict Make inference about the bird species for the preprocessed data passed to this function as arguments.

        Args:
            sample (np.ndarray): A single preprocessed data chunk. It is wrapped into a batch of size one and cast to float32 before being handed to the model.

        Returns:
            np.ndarray: Softmax-normalized model output, with the same shape as the raw output tensor.
        """
        # Wrap the sample into a batch of one; the cast matches the dtype set below.
        data = np.array([sample], dtype="float32")

        # Adapt the input tensor to the actual batch shape; tensors have to be
        # re-allocated after every resize before they can be written to.
        self.model.resize_tensor_input(self.input_layer_index, list(data.shape))
        self.model.allocate_tensors()

        self.model.set_tensor(self.input_layer_index, data)
        self.model.invoke()

        logits = self.model.get_tensor(self.output_layer_index)

        # Normalize along the last axis only (assumed to be the class axis -
        # TODO confirm). For the batch of one used here this equals softmax over
        # the whole tensor, but it stays correct if more rows are ever returned.
        return softmax(logits, axis=-1)

    @classmethod
    def from_cfg(cls, sparrow_folder: str, cfg: dict) -> "Model":
        """
        from_cfg Create a new instance from a dictionary containing keyword arguments. Usually loaded from a config file.

        Args:
            sparrow_folder (str): Installation directory of the Sparrow package. The entry 'model_name' is resolved relative to its 'models' subfolder.
            cfg (dict): Dictionary containing the keyword arguments. Must contain the key 'model_name'. It is not modified.

        Raises:
            KeyError: If 'model_name' is missing from cfg.

        Returns:
            Model: New model instance created with the supplied kwargs.
        """
        # Work on a shallow copy so the caller's config is left untouched and can
        # be reused without the path being joined a second time.
        kwargs = dict(cfg)

        # NOTE(review): __init__ requires 'model_path', but only 'model_name' is
        # rewritten here; unless cfg also carries 'model_path', cls(**kwargs)
        # raises TypeError. Confirm which key is intended.
        kwargs["model_name"] = str(
            Path(sparrow_folder) / Path("models") / kwargs["model_name"]
        )

        return cls(**kwargs)
|
|
|