import json
import logging
import os

import numpy as np
import tensorflow as tf

from utils.preprocessing import preprocess_image


class AfroPalmModel:
    """
    Class to load the model and make predictions

    Reads a TensorFlow Lite model file into memory, builds a
    ``tf.lite.Interpreter`` from its bytes and exposes :meth:`predict`
    to classify a single image.
    """

    # Default model location, relative to the current working directory.
    MODEL_DIR = "models"
    MODEL_FILENAME = "ghostnet_model_float32.tflite"

    def __init__(self, model_path=None):
        """
        Load the TFLite model and prepare the interpreter

        :param model_path: optional path to the ``.tflite`` file; when omitted,
            ``<current working directory>/models/ghostnet_model_float32.tflite``
            is used
        :raises OSError: if the model file cannot be opened or read
        """
        logging.info("Loading classification model")
        if model_path is None:
            # NOTE: abspath() of a bare file name resolves against the current
            # working directory, so the default depends on where the process
            # was started, not on where this module lives.
            base_dir = os.path.dirname(os.path.abspath(self.MODEL_FILENAME))
            model_path = os.path.join(
                base_dir, self.MODEL_DIR, self.MODEL_FILENAME
            )
        self.model_path = model_path

        logging.debug("Preparing to read from %s", self.model_path)
        with open(self.model_path, "rb") as fid:
            tflite_model = fid.read()
        logging.info("File read successfully")

        # Create and allocate the interpreter using the loaded state
        self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
        self.interpreter.allocate_tensors()

        # Retrieve the input and output indices (first tensor of each only)
        self.input_index = self.interpreter.get_input_details()[0]["index"]
        self.output_index = self.interpreter.get_output_details()[0]["index"]
        logging.info("Model loaded successfully")

    def predict(self, image_path):
        """
        Make a prediction on the image

        :param image_path: value handed to ``preprocess_image``; presumably the
            path of the image file (confirm against ``utils.preprocessing``)
        :return: tuple ``(index, score)`` where ``index`` is the position of
            the highest value in the first row of the output tensor and
            ``score`` is that value
        :raises ValueError: if the model output row is empty
        """
        logging.info("Making prediction")
        img = preprocess_image(image_path)
        # np.shape avoids copying the image just to log its dimensions
        logging.debug("Image preprocessed with shape %s", np.shape(img))

        self.interpreter.set_tensor(self.input_index, img)
        self.interpreter.invoke()

        # Only row 0 of the output is read; assumes it holds the per-class
        # scores for the single input image -- TODO confirm for other models.
        scores = self.interpreter.get_tensor(self.output_index)[0]
        best_index = int(np.argmax(scores))  # one pass over the scores
        logging.info("Classification successful")
        return best_index, scores[best_index]