# =========================================================
# AI DRIVER SAFETY DETECTION SYSTEM
# HuggingFace Gradio App
# =========================================================

import gradio as gr
import numpy as np
import cv2

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array

# =========================================================
# LOAD MODEL
# =========================================================

# Loaded once at import time so every request reuses the same model.
model = load_model("final_driver_state_model.h5")

# =========================================================
# CLASS LABELS
# IMPORTANT:
# Must match training class order exactly
# =========================================================

CLASS_NAMES = [
    "alert",
    "sleepy",
    "slowBlink",
    "yawning",
]

# =========================================================
# RISK LEVELS
# =========================================================

RISK_LEVELS = {
    "alert": "SAFE",
    "sleepy": "HIGH RISK",
    "slowBlink": "MEDIUM RISK",
    "yawning": "LOW RISK",
}

# =========================================================
# EMOJIS
# =========================================================

RISK_EMOJIS = {
    "SAFE": "🟢",
    "LOW RISK": "🟡",
    "MEDIUM RISK": "🟠",
    "HIGH RISK": "🔴",
}

# =========================================================
# IMAGE PREPROCESSING
# IMPORTANT:
# Match training preprocessing
# =========================================================

# (width, height) passed to cv2.resize; must match the size used in training.
IMG_SIZE = (224, 224)


def preprocess_image(image, target_size=IMG_SIZE):
    """Convert an uploaded RGB image into a normalized single-image batch.

    Args:
        image: Image array from ``gr.Image(type="numpy")``. Expected to be
            RGB (HxWx3); grayscale (HxW or HxWx1) and RGBA (HxWx4) arrays
            are converted to 3-channel RGB first.
        target_size: ``(width, height)`` given to ``cv2.resize``. Defaults
            to ``IMG_SIZE``.

    Returns:
        float32 array of shape ``(1, height, width, 3)``. Pixel values are
        divided by 255, so a uint8 input lands in [0, 1].
    """
    image = np.asarray(image)

    # Gradio already provides RGB, so there is NO BGR<->RGB cvtColor here.
    # Only the channel count is normalized: the pipeline feeds the model
    # 3-channel RGB, and grayscale/RGBA arrays would otherwise arrive with
    # the wrong shape.
    if image.ndim == 2 or image.shape[-1] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    image = cv2.resize(image, target_size)
    image = image.astype("float32") / 255.0
    image = img_to_array(image)

    # Add the batch axis expected by model.predict.
    return np.expand_dims(image, axis=0)


# =========================================================
# PREDICTION FUNCTION
# =========================================================

def predict_driver_state(image):
    """Classify a driver image and summarize the associated risk level.

    Args:
        image: numpy image from the Gradio input, or ``None`` when nothing
            has been uploaded.

    Returns:
        ``(result_text, confidence_scores)``: a human-readable summary and a
        dict mapping every class name to the model's score for it (consumed
        by ``gr.Label``). Returns a prompt message and ``{}`` when ``image``
        is ``None``.

    Raises:
        ValueError: if the number of model outputs differs from
            ``len(CLASS_NAMES)``.
    """
    if image is None:
        return "Please upload an image.", {}

    # =====================================================
    # PREPROCESS + PREDICTION
    # =====================================================

    processed_image = preprocess_image(image)
    prediction = model.predict(processed_image, verbose=0)

    # Scores for the single image in the batch. Presumably softmax
    # probabilities; TODO confirm against the model's final layer.
    scores = np.asarray(prediction)[0]

    # A label/output mismatch would mislabel predictions or fail with an
    # opaque IndexError, so report it explicitly.
    if scores.shape[0] != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {scores.shape[0]} class scores but CLASS_NAMES "
            f"has {len(CLASS_NAMES)} entries; they must match the training "
            "class order."
        )

    # =====================================================
    # RESULTS
    # =====================================================

    class_index = int(np.argmax(scores))
    predicted_class = CLASS_NAMES[class_index]
    confidence = float(scores[class_index])
    risk_level = RISK_LEVELS[predicted_class]
    emoji = RISK_EMOJIS[risk_level]

    # =====================================================
    # CONFIDENCE SCORES
    # =====================================================

    confidence_scores = {
        class_name: float(score)
        for class_name, score in zip(CLASS_NAMES, scores)
    }

    # =====================================================
    # RESULT TEXT
    # =====================================================

    result = (
        "🚗 DRIVER STATE ANALYSIS\n\n"
        f"Prediction: {predicted_class.upper()}\n\n"
        f"Confidence: {confidence:.2f}\n\n"
        f"Risk Level: {emoji} {risk_level}"
    )

    return result, confidence_scores


# =========================================================
# TITLE & DESCRIPTION
# =========================================================

title = "🚗 AI Driver Safety Detection System"

# Rendered as Markdown by Gradio; feature lines are list items so they do
# not collapse into a single paragraph.
description = """
Upload a driver image to analyze fatigue and attention state using Deep Learning.

## Supported Driver States

- 🟢 Alert
- 🔴 Sleepy
- 🟠 Slow Blink
- 🟡 Yawning

## AI Features

- ✅ CNN-Based Driver State Classification
- ✅ Fatigue Risk Analysis
- ✅ Deep Learning Inference
- ✅ Real-Time Prediction Engine
"""

# =========================================================
# GRADIO INTERFACE
# =========================================================

interface = gr.Interface(
    fn=predict_driver_state,
    inputs=gr.Image(
        type="numpy",
        # Explicitly request RGB so preprocess_image receives the channel
        # order the model pipeline assumes.
        image_mode="RGB",
        label="Upload Driver Image",
    ),
    outputs=[
        gr.Textbox(label="Prediction Result"),
        gr.Label(label="Confidence Scores"),
    ],
    title=title,
    description=description,
)

# =========================================================
# LAUNCH
# =========================================================

interface.launch()