# MaHaWo - reformat model impl - 97f3fc0
from iSparrow.sparrow_model_base import ModelBase
try:
import tflite_runtime.interpreter as tflite
except ImportError:
import tensorflow.lite as tflite
from iSparrow import utils
from iSparrow import ModelBase
import numpy as np
from pathlib import Path
from scipy.special import softmax
class Model(ModelBase):
    """
    Model Implementation of an iSparrow model that uses the google perch tflite model.

    Args:
        ModelBase (iSparrow.ModelBase): Model base class that provides the interface through which to interact with iSparrow.
    """

    def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
        """
        __init__ Create a new model instance that uses the google perch tflite converted model.

        Args:
            model_path (str): directory in which "model.tflite" and "labels.txt" are stored
            num_threads (int, optional): number of threads to use. Defaults to 1.
            **kwargs: passed on unchanged to the base class constructor.
        """
        model_dir = Path(model_path)
        labels_path = str(model_dir / "labels.txt")
        model_path = str(model_dir / "model.tflite")

        # base class loads the model and labels
        super().__init__(
            "google_perch_lite",
            model_path,
            labels_path,
            num_threads=num_threads,
            **kwargs
        )

        # store input and output index to not have to retrieve them each time an inference is made
        input_details = self.model.get_input_details()
        output_details = self.model.get_output_details()

        self.input_layer_index = input_details[0]["index"]
        # NOTE(review): the second output tensor is used here; presumably that is
        # the logits output of the perch model - confirm against the converted model.
        self.output_layer_index = output_details[1]["index"]

    def predict(self, sample: np.ndarray) -> np.ndarray:
        """
        predict Make inference about the bird species for the preprocessed data passed to this function as arguments.

        Args:
            sample (np.ndarray): a single preprocessed data chunk. It is wrapped into a batch of size one and converted to float32 before it is handed to the interpreter.

        Returns:
            np.ndarray: softmax of the tensor read from the stored output index
        """
        data = np.array([sample], dtype="float32")

        # the input tensor is resized to the shape of the batch on every call,
        # which makes a re-allocation of the tensors necessary
        self.model.resize_tensor_input(self.input_layer_index, list(data.shape))
        self.model.allocate_tensors()

        # Make a prediction
        self.model.set_tensor(self.input_layer_index, data)
        self.model.invoke()

        logits = self.model.get_tensor(self.output_layer_index)

        # softmax without an axis normalises over all elements of the array; the
        # batch built above holds exactly one sample
        confidence = softmax(logits)

        return confidence

    @classmethod
    def from_cfg(cls, sparrow_folder: str, cfg: dict):
        """
        from_cfg Create a new instance from a dictionary containing keyword arguments. Usually loaded from a config file.

        Args:
            sparrow_folder (str): Installation directory of the Sparrow package
            cfg (dict): Dictionary containing the keyword arguments. Must contain "model_name", the name of the model folder below "<sparrow_folder>/models". The dictionary itself is not modified.

        Raises:
            KeyError: if cfg has no "model_name" entry.

        Returns:
            Model: New model instance created with the supplied kwargs.
        """
        # work on a shallow copy so the caller's config is left untouched
        params = dict(cfg)

        model_dir = Path(sparrow_folder) / Path("models") / params.pop("model_name")

        # the constructor expects the folder as 'model_path', not 'model_name';
        # a 'model_path' that is already present in cfg takes precedence
        params.setdefault("model_path", str(model_dir))

        return cls(**params)