Spaces:
ckcl
/
Build error

dnn_space_old / inference.py
ckcl's picture
Upload 12 files
2c82790 verified
raw
history blame
2.86 kB
import base64
import io
import logging
import os

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
class DrowsinessDetector:
    """Drowsiness classifier wrapping a Keras model.

    Input images are decoded, converted to 3-channel RGB, resized to
    ``input_shape`` and divided by 255 before being fed to the model.
    """

    def __init__(self, input_shape=(64, 64, 3)):
        """Create a detector with no model loaded yet.

        Args:
            input_shape: ``(height, width, channels)`` expected by the model.
                Only height and width are used for resizing; images are
                always converted to 3 channels.
        """
        self.model = None
        self.input_shape = tuple(input_shape)

    def load_model(self, model_path):
        """Load a saved Keras model from ``model_path``."""
        self.model = tf.keras.models.load_model(model_path)

    @staticmethod
    def _decode_base64(data):
        """Decode a base64 string, tolerating a ``data:<mime>;base64,`` prefix."""
        if data.startswith("data:"):
            # Browser data URLs prepend metadata; only the text after the
            # first comma is the base64 payload.
            _, _, data = data.partition(",")
        return base64.b64decode(data)

    @staticmethod
    def _open_rgb(image_bytes):
        """Decode encoded image bytes (JPEG, PNG, ...) into an RGB array."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert("RGB") also normalizes palette, CMYK and gray+alpha modes,
            # which a plain np.array() would pass through with the wrong channels.
            return np.array(img.convert("RGB"))

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as a contiguous ``(H, W, 3)`` array.

        Grayscale inputs (``(H, W)``, ``(H, W, 1)`` or gray+alpha ``(H, W, 2)``)
        have their first channel replicated; inputs with 3 or more channels
        keep the first three (dropping alpha).

        Raises:
            ValueError: if the array is empty or not 2-D/3-D.
        """
        image = np.asarray(image)
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.ndim != 3 or image.size == 0:
            raise ValueError(
                f"Expected a non-empty 2-D or 3-D image array, got shape {image.shape}"
            )
        if image.shape[2] < 3:
            image = np.repeat(image[:, :, :1], 3, axis=2)
        else:
            image = image[:, :, :3]
        # Slicing can yield a non-contiguous view; OpenCV prefers contiguous input.
        return np.ascontiguousarray(image)

    def preprocess_image(self, image):
        """Turn an input image into a model-ready batch of one.

        Args:
            image: a base64 string (optionally a ``data:`` URL), raw encoded
                image bytes, a ``PIL.Image.Image``, or an array-like of shape
                ``(H, W)`` or ``(H, W, C)`` whose channels are assumed to be
                in RGB order.

        Returns:
            ``float32`` array of shape ``(1, height, width, 3)`` with pixel
            values divided by 255 (i.e. in ``[0, 1]`` for 8-bit input).

        Raises:
            ValueError: if the decoded array is empty or has unsupported rank.
            Decoding errors from ``base64``/PIL propagate unchanged.
        """
        if isinstance(image, str):
            image = self._open_rgb(self._decode_base64(image))
        elif isinstance(image, (bytes, bytearray, memoryview)):
            image = self._open_rgb(bytes(image))
        elif isinstance(image, Image.Image):
            image = np.array(image.convert("RGB"))
        image = self._to_rgb(image)

        height, width = self.input_shape[:2]
        # cv2.resize takes dsize as (width, height) -- the reverse of the
        # (rows, cols) order used by input_shape.
        image = cv2.resize(image, (width, height))
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def predict(self, image, threshold=0.5):
        """Classify a single image.

        Args:
            image: any input accepted by :meth:`preprocess_image`.
            threshold: probability above which ``is_drowsy`` is True.

        Returns:
            ``{"drowsy_probability": float, "is_drowsy": bool}``.

        Raises:
            ValueError: if no model has been loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        batch = self.preprocess_image(image)
        # verbose=0 suppresses Keras' progress bar, which otherwise prints
        # on every single-image request.
        prediction = self.model.predict(batch, verbose=0)
        # NOTE(review): assumes the model's first output unit is the drowsy
        # probability (e.g. a single sigmoid unit) -- confirm with training code.
        probability = float(prediction[0][0])
        return {
            "drowsy_probability": probability,
            "is_drowsy": probability > threshold,
        }
# Module-level singleton shared by the load_model() and predict() functions.
detector: DrowsinessDetector = DrowsinessDetector()
def load_model(model_path="model_weights.h5"):
    """Load the model into the shared module-level ``detector``.

    Intended to be called once when the API starts.

    Args:
        model_path: path to the saved Keras model (default
            ``"model_weights.h5"``). A relative path is tried against the
            current working directory first; if no file exists there, the
            same relative path next to this module is used instead.
    """
    if not os.path.isabs(model_path) and not os.path.exists(model_path):
        # Servers are not always launched from the app directory; fall back
        # to the copy shipped alongside this file.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, model_path)
        if os.path.exists(candidate):
            model_path = candidate
    # No `global` needed: the name is only read, never rebound.
    detector.load_model(model_path)
def predict(image):
    """API endpoint for prediction.

    Runs the shared ``detector`` on ``image`` and wraps the outcome in a
    status envelope instead of raising.

    Returns:
        ``{"status": "success", "prediction": <detector result>}`` on success,
        or ``{"status": "error", "message": str(exc)}`` if anything raised.
    """
    try:
        result = detector.predict(image)
    except Exception as e:  # API boundary: every failure is reported to the caller
        # Keep the full traceback in server logs; the caller only gets the message.
        logging.getLogger(__name__).exception("Drowsiness prediction failed")
        return {
            "status": "error",
            "message": str(e),
        }
    return {
        "status": "success",
        "prediction": result,
    }
# For local testing: python inference.py [path/to/image.jpg]
if __name__ == "__main__":
    import sys

    load_model()
    # Default sample image; pass a path on the command line to override it.
    test_image_path = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"
    if os.path.exists(test_image_path):
        with open(test_image_path, "rb") as f:
            image_data = f.read()
        result = predict(image_data)
        print("Prediction result:", result)
    else:
        # Previously a missing image made the script exit silently.
        print(f"Test image not found: {test_image_path}")