File size: 2,864 Bytes
2c82790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import io
import base64
import os

class DrowsinessDetector:
    """Wrap a Keras drowsiness classifier together with its image preprocessing.

    The model is not loaded on construction; call ``load_model()`` before
    ``predict()``.
    """

    def __init__(self):
        # Keras model instance; stays None until load_model() is called.
        self.model = None
        # Target size images are resized to before inference, interpreted as
        # (height, width, channels).
        # NOTE(review): must match the trained model's input layer - confirm.
        self.input_shape = (64, 64, 3)

    def load_model(self, model_path):
        """Load the Keras model stored at ``model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(model_path)

    @staticmethod
    def _decode_image(encoded):
        """Decode encoded image bytes (PNG, JPEG, ...) into an RGB array."""
        with Image.open(io.BytesIO(encoded)) as decoded:
            # convert("RGB") normalises every Pillow mode (palette, CMYK,
            # grayscale, with or without alpha) to exactly three channels.
            return np.array(decoded.convert("RGB"))

    def preprocess_image(self, image):
        """Turn an input image into a single-item batch for the model.

        Args:
            image: One of
                * ``str`` - base64-encoded image file, optionally prefixed
                  with a ``data:<mime>;base64,`` data-URI header;
                * ``bytes`` / ``bytearray`` - raw encoded image file contents;
                * array-like - pixel data shaped (H, W), (H, W, 1), (H, W, 3)
                  or (H, W, 4). Colour arrays are assumed to already be in
                  RGB(A) channel order.

        Returns:
            ``float32`` array of shape ``(1, height, width, 3)`` whose values
            are the resized pixels divided by 255.

        Raises:
            ValueError: If the pixel data is empty or has an unsupported shape.
        """
        if isinstance(image, str):
            # Browser clients commonly send "data:image/png;base64,<payload>";
            # only the payload after the first comma is base64.
            if image.startswith("data:") and "," in image:
                image = image.split(",", 1)[1]
            image = self._decode_image(base64.b64decode(image))
        elif isinstance(image, (bytes, bytearray)):
            image = self._decode_image(bytes(image))
        else:
            image = np.asarray(image)

        # Collapse an explicit single-channel axis so it is handled like (H, W).
        if image.ndim == 3 and image.shape[2] == 1:
            image = np.ascontiguousarray(image[:, :, 0])

        # Bring everything to three RGB channels.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        elif image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(
                "Unsupported image shape %s; expected (H, W), (H, W, 1), "
                "(H, W, 3) or (H, W, 4)." % (image.shape,)
            )

        if image.size == 0:
            raise ValueError("Image is empty.")

        # cv2.resize expects dsize as (width, height), the reverse of the
        # (height, width) order used by array shapes.
        height, width = self.input_shape[:2]
        image = cv2.resize(image, (width, height))

        # Scale to [0, 1] for 8-bit input and add the leading batch axis.
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def predict(self, image):
        """Run the model on one image and report the drowsiness score.

        Args:
            image: Any input accepted by ``preprocess_image``.

        Returns:
            Dict with ``"drowsy_probability"`` (the first value of the first
            model output row, as ``float``) and ``"is_drowsy"`` (``True`` when
            that value exceeds 0.5).
            NOTE(review): assumes the model emits a single drowsiness score
            per image - confirm against the trained model.

        Raises:
            ValueError: If no model has been loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        processed_image = self.preprocess_image(image)
        prediction = self.model.predict(processed_image)

        score = float(prediction[0][0])
        return {
            "drowsy_probability": score,
            "is_drowsy": score > 0.5,
        }

# Module-level detector instance shared by the load_model() and predict()
# functions below; its model stays unset until load_model() is called.
detector = DrowsinessDetector()

def load_model(model_path="model_weights.h5"):
    """Load the model into the shared detector when the API starts.

    Args:
        model_path: Path of the saved model file. Defaults to
            ``"model_weights.h5"``, which is resolved relative to the
            current working directory.
    """
    # No ``global`` statement needed: the module-level name is only read,
    # never rebound.
    detector.load_model(model_path)

def predict(image):
    """API endpoint: classify ``image`` with the shared detector.

    Never raises. Any exception from the detector is reported in the
    returned dict instead.

    Returns:
        ``{"status": "success", "prediction": <detector result>}`` on
        success, otherwise ``{"status": "error", "message": <exception text>}``.
    """
    try:
        outcome = detector.predict(image)
    except Exception as err:
        # Boundary handler: turn every failure into an error payload.
        return {"status": "error", "message": str(err)}
    return {"status": "success", "prediction": outcome}

# Manual smoke test: run this file directly to classify one local image.
if __name__ == "__main__":
    load_model()

    # Replace with your test image; nothing is predicted if the file is absent.
    sample_path = "test_image.jpg"
    if os.path.exists(sample_path):
        with open(sample_path, "rb") as handle:
            payload = handle.read()
        print("Prediction result:", predict(payload))