# NOTE(review): the lines "Spaces: / Running / Running" were scrape artifacts
# from the hosting UI (Hugging Face Spaces status banner), not program text.
from utils.preprocessing import preprocess_image
import tensorflow as tf
import numpy as np
import os
import logging
import json
class AfroPalmModel:
    """Load the GhostNet TFLite classifier and run single-image predictions.

    The model file is expected at ``models/ghostnet_model_float32.tflite``
    relative to the current working directory. (The original path expression
    ``os.path.dirname(os.path.abspath(<bare filename>))`` always resolves to
    the CWD, so spelling it out directly is behavior-preserving.)
    """

    def __init__(self):
        """Read the TFLite model bytes and allocate an interpreter.

        :raises FileNotFoundError: if the model file is not present under
            ``models/`` in the current working directory.
        """
        logging.info("Loading classification model")
        # os.path.dirname(os.path.abspath(<bare filename>)) is just the CWD;
        # say so directly instead of the roundabout original construction.
        self.model_path = os.path.join(
            os.getcwd(), "models", "ghostnet_model_float32.tflite"
        )
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logging.debug("Preparing to read from %s", self.model_path)
        with open(self.model_path, "rb") as fid:
            tflite_model = fid.read()
        logging.info("File read successfully")
        # Create and allocate the interpreter using the loaded bytes.
        self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
        self.interpreter.allocate_tensors()
        # Cache the tensor indices used by predict().
        self.input_index = self.interpreter.get_input_details()[0]["index"]
        self.output_index = self.interpreter.get_output_details()[0]["index"]
        logging.info("Model loaded successfully")

    def predict(self, image_path):
        """Classify a single image.

        :param image_path: path to the image file to classify
            (fixed: the original docstring documented a ``image`` parameter
            that does not exist).
        :return: tuple ``(class_index, confidence)`` — the argmax of the
            model's output vector and the score at that index.
        """
        logging.info("Making prediction")
        img = preprocess_image(image_path)
        logging.debug("Image preprocessed with shape %s", np.array(img).shape)
        self.interpreter.set_tensor(self.input_index, img)
        self.interpreter.invoke()
        scores = self.interpreter.get_tensor(self.output_index)[0]
        logging.info("Classification successful")
        # Single argmax pass instead of the original's two max() scans plus
        # a list.index() scan; int() keeps the index a plain Python int as
        # list.index() returned.
        best = int(np.argmax(scores))
        return best, scores[best]