dnn_space / inference.py
ckcl's picture
Upload 12 files
2c82790 verified
import base64
import io
import logging
import os
import sys

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
class DrowsinessDetector:
    """Wrap a Keras model that scores a single image for drowsiness.

    Call :meth:`load_model` before :meth:`predict`. Images are converted to
    RGB, resized to ``input_shape`` and divided by 255 before inference.
    NOTE(review): assumes the model's first output unit is the drowsy
    probability (e.g. a single sigmoid unit) — confirm against training code.
    """

    def __init__(self):
        self.model = None  # populated by load_model()
        # Model input as (height, width, channels), channels-last.
        self.input_shape = (64, 64, 3)

    def load_model(self, model_path):
        """Load the Keras model stored at ``model_path``."""
        self.model = tf.keras.models.load_model(model_path)

    @staticmethod
    def _decode_base64(data):
        """Decode a base64 string, stripping an optional ``data:...;base64,`` prefix."""
        if data.startswith("data:"):
            # b64decode (non-validating) drops ':', ';' and ',' but keeps the
            # prefix's letters, which would silently corrupt the payload.
            _, _, data = data.partition(",")
        return base64.b64decode(data)

    def preprocess_image(self, image):
        """Convert one input image into a model-ready batch of size one.

        Args:
            image: one of
                * a base64 ``str`` (a ``data:<mime>;base64,`` URI prefix is
                  stripped first),
                * raw encoded image bytes (``bytes``/``bytearray``/``memoryview``),
                * a ``PIL.Image.Image``,
                * an array-like of pixels shaped (H, W), (H, W, 1), (H, W, 3)
                  or (H, W, 4). Arrays are taken to be RGB(A) as-is —
                  NOTE(review): ``cv2.imread`` output is BGR and is not
                  reordered here.

        Returns:
            ``float32`` array of shape ``(1, height, width, 3)`` divided by 255
            (i.e. in ``[0, 1]`` for 8-bit input).
        """
        if isinstance(image, str):
            image = self._decode_base64(image)

        if isinstance(image, (bytes, bytearray, memoryview)):
            # PIL decodes the container and normalises every mode
            # (L, LA, P, 1, RGBA, CMYK, ...) to RGB in one step; the context
            # manager closes the underlying image afterwards.
            with Image.open(io.BytesIO(image)) as pil_image:
                image = np.array(pil_image.convert("RGB"))
        elif isinstance(image, Image.Image):
            image = np.array(image.convert("RGB"))
        else:
            image = np.asarray(image)

        # Bring raw pixel arrays to exactly 3 channels.
        if image.ndim == 3 and image.shape[2] == 1:
            # Otherwise cv2.resize would drop the singleton channel axis.
            image = image.reshape(image.shape[:2])
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        height, width = self.input_shape[:2]
        # cv2.resize takes dsize as (width, height), the reverse of numpy's
        # (rows, cols) order.
        image = cv2.resize(image, (width, height))
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def predict(self, image, threshold=0.5):
        """Classify a single image.

        Args:
            image: any input accepted by :meth:`preprocess_image`.
            threshold: the image is flagged as drowsy when the probability is
                strictly greater than this value.

        Returns:
            dict with ``"drowsy_probability"`` (float) and ``"is_drowsy"`` (bool).

        Raises:
            ValueError: if :meth:`load_model` has not been called.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        batch = self.preprocess_image(image)
        # verbose=0 suppresses Keras' per-call progress bar, which is pure
        # noise when scoring one image per request.
        prediction = self.model.predict(batch, verbose=0)
        probability = float(prediction[0][0])
        return {
            "drowsy_probability": probability,
            "is_drowsy": probability > threshold,
        }
# Shared detector instance used by the module-level load_model()/predict() API.
detector: DrowsinessDetector = DrowsinessDetector()
def load_model(model_path="model_weights.h5"):
    """Load the model into the module-level ``detector`` when the API starts.

    Args:
        model_path: path to the saved Keras model. The default is resolved
            against the current working directory.
    """
    # No ``global`` statement needed: the shared detector is mutated, the
    # name itself is never rebound.
    detector.load_model(model_path)
def predict(image):
    """API endpoint: classify ``image`` and wrap the result in a status envelope.

    Any exception raised during preprocessing or inference is caught so the
    endpoint always returns a dict; the traceback is logged rather than
    silently discarded.

    Args:
        image: base64 string, raw image bytes, or pixel array, as accepted by
            ``DrowsinessDetector.preprocess_image``.

    Returns:
        ``{"status": "success", "prediction": {...}}`` on success, or
        ``{"status": "error", "message": str(exc)}`` on failure.
    """
    try:
        result = detector.predict(image)
    except Exception as e:  # API boundary: report failures instead of raising
        logging.getLogger(__name__).exception("Drowsiness prediction failed")
        return {
            "status": "error",
            "message": str(e)
        }
    return {
        "status": "success",
        "prediction": result
    }
# For local testing: python inference.py [path/to/image.jpg]
if __name__ == "__main__":
    load_model()

    # An optional CLI argument overrides the default sample image path.
    test_image_path = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"
    if os.path.exists(test_image_path):
        with open(test_image_path, "rb") as f:
            image_data = f.read()
        result = predict(image_data)
        print("Prediction result:", result)
    else:
        # Previously a missing image produced no output at all.
        print(f"Test image not found: {test_image_path}")