# PalmOil-Classification / utils/palmoil_classification.py
# Last change by Mawube: "Use only image urls" (commit ed89f9e, unverified)
from utils.preprocessing import preprocess_image
import tensorflow as tf
import numpy as np
import os
import logging
import json
class AfroPalmModel:
    """
    Load the GhostNet TFLite palm-oil classifier and run single-image predictions.

    The model file is read once at construction time; every call to
    :meth:`predict` reuses the same allocated TFLite interpreter.
    """

    # File name of the bundled model, looked up under ``<cwd>/models/`` by default.
    _MODEL_FILENAME = "ghostnet_model_float32.tflite"

    def __init__(self, model_path=None):
        """
        Read the TFLite model from disk and prepare an interpreter for inference.

        :param model_path: optional explicit path (str or path-like) to a
            ``.tflite`` file. When omitted, the default location
            ``<current working directory>/models/ghostnet_model_float32.tflite``
            is used.
        :raises OSError: if the model file cannot be opened (e.g. it does not exist).
        """
        logging.info("Loading classification model")
        self.model_path = self._resolve_model_path(model_path)
        logging.debug("Preparing to read from %s", self.model_path)
        with open(self.model_path, "rb") as fid:
            tflite_model = fid.read()
        logging.info("File read successfully")
        # Create and allocate the interpreter from the in-memory model bytes
        self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
        self.interpreter.allocate_tensors()
        # Only the first input tensor and the first output tensor are used.
        self.input_index = self.interpreter.get_input_details()[0]["index"]
        self.output_index = self.interpreter.get_output_details()[0]["index"]
        logging.info("Model loaded successfully")

    @classmethod
    def _resolve_model_path(cls, model_path=None):
        """Return *model_path* as a string, or the default model location when it is None."""
        if model_path is not None:
            return os.fspath(model_path)
        # Resolved against the current working directory, not this module's
        # location: the process must be started from the directory that
        # contains ``models/``.
        return os.path.join(os.getcwd(), "models", cls._MODEL_FILENAME)

    @staticmethod
    def _top_prediction(scores):
        """
        Return ``(index, value)`` of the largest entry in a 1-D sequence of scores.

        The index is a plain Python ``int``. On ties, the first maximal entry
        wins. An empty input raises ``ValueError``.
        """
        scores = np.asarray(scores)
        best = int(np.argmax(scores))
        return best, scores[best]

    def predict(self, image_path):
        """
        Classify a single image.

        :param image_path: image reference passed unchanged to ``preprocess_image``
            (presumably a URL or local path — confirm against utils.preprocessing).
        :return: tuple ``(class_index, confidence)``. ``class_index`` is the
            position of the largest value in the model's first output row, and
            ``confidence`` is that value. Whether it is a probability depends on
            the model's final layer, which is not visible here.
        """
        logging.info("Making prediction")
        img = preprocess_image(image_path)
        logging.debug("Image preprocessed with shape %s", np.shape(img))
        self.interpreter.set_tensor(self.input_index, img)
        self.interpreter.invoke()
        # Only row 0 of the output is read (assumes a batch of one image).
        scores = self.interpreter.get_tensor(self.output_index)[0]
        logging.info("Classification successful")
        return self._top_prediction(scores)