|
|
import tensorflow as tf
|
|
|
import numpy as np
|
|
|
import cv2
|
|
|
from PIL import Image
|
|
|
import io
|
|
|
import base64
|
|
|
import os
|
|
|
|
|
|
class DrowsinessDetector:
    """Wrap a Keras model that scores a single image for drowsiness.

    Call :meth:`load_model` before :meth:`predict`. Images may be supplied as
    a base64 string, encoded image bytes, or a decoded pixel array; they are
    converted to 3 channels, resized to ``input_shape`` and divided by 255.
    """

    def __init__(self, input_shape=(64, 64, 3), threshold=0.5):
        """Create a detector with no model loaded.

        Args:
            input_shape: ``(height, width, channels)`` of one model input.
                Preprocessing always produces 3 (RGB) channels.
            threshold: ``is_drowsy`` is True when the predicted probability is
                strictly greater than this value.
        """
        self.model = None
        self.input_shape = tuple(input_shape)
        self.threshold = threshold

    def load_model(self, model_path):
        """Load a saved Keras model from *model_path* into ``self.model``.

        Errors from ``tf.keras.models.load_model`` (e.g. a missing or
        unreadable file) propagate to the caller.
        """
        self.model = tf.keras.models.load_model(model_path)

    def preprocess_image(self, image):
        """Convert *image* into a float32 batch of shape ``(1, H, W, 3)``.

        Accepted inputs:
          * ``str`` - base64-encoded image file contents, optionally prefixed
            with a ``data:<mime>;base64,`` URL header;
          * ``bytes`` / ``bytearray`` / ``memoryview`` - encoded image file
            contents (any format Pillow can open);
          * an array-like of decoded pixels with shape ``(H, W)``,
            ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``. Channel order is
            taken as-is (RGB/RGBA assumed; BGR arrays from ``cv2.imread`` are
            not reordered).

        Pixel values are divided by 255, so array inputs are assumed to be
        8-bit (0-255).

        Raises:
            binascii.Error: if a ``str`` input is not valid base64.
            ValueError: if a pixel array has an unsupported shape.
            Pillow decoding errors propagate for undecodable bytes.
        """
        if isinstance(image, str):
            image = self._decode_base64(image)
        if isinstance(image, (bytes, bytearray, memoryview)):
            image = self._decode_bytes(image)
        image = self._to_rgb(np.asarray(image))

        height, width = self.input_shape[:2]
        if image.shape[:2] != (height, width):
            # cv2.resize takes dsize as (width, height) - the reverse of
            # numpy's (rows, cols) order - so the tuple must be swapped.
            image = cv2.resize(image, (width, height))

        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    @staticmethod
    def _decode_base64(data):
        """Decode a base64 string, tolerating a ``data:...;base64,`` prefix."""
        if data.startswith("data:"):
            # Browser data URLs look like "data:image/jpeg;base64,/9j/...";
            # only the part after the first comma is base64 payload.
            data = data.partition(",")[2]
        return base64.b64decode(data)

    @staticmethod
    def _decode_bytes(data):
        """Decode encoded image file bytes into an RGB uint8 array."""
        with Image.open(io.BytesIO(data)) as img:
            # convert("RGB") normalises every Pillow mode (L, LA, P, RGBA,
            # CMYK, ...) to 3 channels; without it a palette ("P") image
            # becomes an array of palette indices instead of colours.
            return np.array(img.convert("RGB"))

    @staticmethod
    def _to_rgb(image):
        """Return a 3-channel version of a decoded pixel array.

        Raises:
            ValueError: unless the shape is (H, W) or (H, W, C), C in {1, 3, 4}.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            # Alpha is discarded, not composited onto a background.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        if image.ndim == 3 and image.shape[2] == 3:
            return image
        raise ValueError(
            f"Unsupported image array shape {image.shape}; expected (H, W) "
            "or (H, W, C) with C in 1, 3 or 4"
        )

    def predict(self, image):
        """Score one image and return the drowsiness probability and verdict.

        Args:
            image: Any input accepted by :meth:`preprocess_image`.

        Returns:
            ``{"drowsy_probability": float, "is_drowsy": bool}`` where
            ``is_drowsy`` is ``drowsy_probability > self.threshold``.

        Raises:
            ValueError: if no model has been loaded, or the image array has
                an unsupported shape.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        batch = self.preprocess_image(image)
        # verbose=0 stops Keras printing a progress bar on every request.
        prediction = self.model.predict(batch, verbose=0)
        # NOTE(review): assumes a single sigmoid output unit, i.e. output shape
        # (1, 1); with a 2-unit softmax head, index 0 is the other class.
        probability = float(prediction[0][0])
        return {
            "drowsy_probability": probability,
            "is_drowsy": bool(probability > self.threshold),
        }
|
|
|
|
|
|
|
|
|
# Module-level instance shared by the load_model() and predict() entry points
# below. It holds no model until load_model() is called.
detector: DrowsinessDetector = DrowsinessDetector()
|
|
|
|
|
|
def load_model(model_path="model_weights.h5"):
    """Load the model into the shared ``detector`` when the API starts.

    Args:
        model_path: Path of the saved Keras model. A relative path is resolved
            against the current working directory. Defaults to
            ``"model_weights.h5"``.
    """
    # NOTE(review): tf.keras.models.load_model needs a full saved model
    # (architecture + weights); a bare save_weights() checkpoint will not
    # load here despite the default file name - confirm.
    # No ``global`` statement: the name ``detector`` is only read, never rebound.
    detector.load_model(model_path)
|
|
|
|
|
|
def predict(image):
    """API entry point: classify *image* and wrap the result in a status envelope.

    Args:
        image: Any input accepted by ``DrowsinessDetector.preprocess_image``.

    Returns:
        ``{"status": "success", "prediction": {...}}`` on success, otherwise
        ``{"status": "error", "message": str}``. Any ``Exception`` raised during
        prediction is logged with its traceback and converted into the error
        envelope instead of propagating.
    """
    import logging

    try:
        result = detector.predict(image)
    except Exception as e:  # API boundary: report every failure to the caller
        logging.getLogger(__name__).exception("Drowsiness prediction failed")
        # NOTE(review): str(e) can expose internal details to API clients.
        # Some exceptions have an empty str(); fall back to the type name so
        # the message is never blank.
        return {"status": "error", "message": str(e) or type(e).__name__}
    return {"status": "success", "prediction": result}
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the model and classify one local image.
    # Usage: python <this file> [image_path]  (default: test_image.jpg)
    import sys

    load_model()

    test_image_path = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"
    # isfile (not exists) so a directory path is reported instead of crashing open().
    if os.path.isfile(test_image_path):
        with open(test_image_path, "rb") as f:
            image_data = f.read()
        result = predict(image_data)
        print("Prediction result:", result)
    else:
        print(f"Test image not found: {test_image_path}", file=sys.stderr)