# Hugging Face Space file: ATOA_DDobj / app.py
# Last commit fa59a09 (verified) by asrcoddeploy: "Update app.py"
# =========================================================
# AI DRIVER SAFETY DETECTION SYSTEM
# HuggingFace Gradio App
# =========================================================
import gradio as gr
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# =========================================================
# LOAD MODEL
# =========================================================
# Relative path to the trained Keras model file; as a relative path it is
# resolved against the process's current working directory.
MODEL_PATH = "final_driver_state_model.h5"

# Loaded once, at import time; every prediction request reuses this object.
model = load_model(MODEL_PATH)
# =========================================================
# CLASS LABELS
# IMPORTANT:
# Must match training class order exactly
# =========================================================
# Entry i is the label for the model's i-th output score.
CLASS_NAMES = ["alert", "sleepy", "slowBlink", "yawning"]
# =========================================================
# RISK LEVELS
# =========================================================
# Risk label reported for each predicted class. Every name in CLASS_NAMES
# needs an entry here, and every value must be a key of RISK_EMOJIS.
RISK_LEVELS = dict(
    alert="SAFE",
    sleepy="HIGH RISK",
    slowBlink="MEDIUM RISK",
    yawning="LOW RISK",
)
# =========================================================
# EMOJIS
# =========================================================
# Colored-circle glyph shown before each risk level in the result text.
# Written as escape sequences so the source stays ASCII and a
# wrong-encoding round trip cannot garble the glyphs.
RISK_EMOJIS = {
    "SAFE": "\U0001F7E2",  # LARGE GREEN CIRCLE
    "LOW RISK": "\U0001F7E1",  # LARGE YELLOW CIRCLE
    "MEDIUM RISK": "\U0001F7E0",  # LARGE ORANGE CIRCLE
    "HIGH RISK": "\U0001F534",  # LARGE RED CIRCLE
}
# =========================================================
# IMAGE PREPROCESSING
# IMPORTANT:
# Match training preprocessing
# =========================================================
def preprocess_image(image, target_size=(224, 224)):
    """Turn one uploaded image into a model-ready batch of size 1.

    Steps: coerce the array to 3 channels, resize it with ``cv2.resize``,
    scale pixel values by 1/255, and add a leading batch axis.

    Args:
        image: NumPy image array, e.g. as delivered by
            ``gr.Image(type="numpy")``. RGB ``(H, W, 3)`` is expected;
            grayscale ``(H, W)`` / ``(H, W, 1)`` inputs are replicated to
            three channels and RGBA ``(H, W, 4)`` inputs have the alpha
            channel dropped (no compositing).
        target_size: ``(width, height)`` passed straight to ``cv2.resize``.
            Defaults to ``(224, 224)``; must match the model's input size.

    Returns:
        float32 array of shape ``(1, height, width, 3)``; for uint8 input the
        values lie in ``[0, 1]``.

    Raises:
        ValueError: if ``image`` is not 2-D or 3-D, or has a channel count
            other than 1, 3 or 4.
    """
    image = np.asarray(image)

    # -----------------------------------------------------
    # Gradio already provides RGB image
    # DO NOT use cvtColor
    # (a BGR/RGB swap would presumably no longer match the
    # channel order the model was trained on)
    # -----------------------------------------------------
    # Normalise the channel layout *before* resizing: cv2.resize drops a
    # trailing length-1 channel axis, and a 4-channel result would not fit
    # the model's 3-channel input.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.ndim != 3:
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {image.shape}"
        )
    channels = image.shape[-1]
    if channels == 1:
        image = np.repeat(image, 3, axis=-1)
    elif channels == 4:
        image = image[..., :3]
    elif channels != 3:
        raise ValueError(
            f"Unsupported channel count {channels}; expected 1, 3 or 4"
        )

    image = cv2.resize(image, target_size)
    image = image.astype("float32") / 255.0
    image = img_to_array(image)
    return np.expand_dims(image, axis=0)
# =========================================================
# PREDICTION FUNCTION
# =========================================================
def predict_driver_state(image):
    """Classify the driver state in one image and summarise its risk.

    Args:
        image: NumPy image array from the Gradio ``gr.Image`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        A ``(result_text, confidence_scores)`` tuple: a human-readable
        summary for the ``gr.Textbox`` output and a ``{class_name: score}``
        dict for the ``gr.Label`` output. When ``image`` is ``None`` the text
        is an upload prompt and the dict is empty.

    Raises:
        ValueError: if the model returns a different number of scores than
            there are entries in ``CLASS_NAMES``.
    """
    if image is None:
        return "Please upload an image.", {}

    processed_image = preprocess_image(image)
    prediction = model.predict(processed_image, verbose=0)

    # The batch holds a single image, so row 0 carries its per-class scores.
    scores = np.asarray(prediction)[0]
    if scores.shape != (len(CLASS_NAMES),):
        # Fail loudly instead of an IndexError or silently ignored scores
        # when the label list and the model's output layer disagree.
        raise ValueError(
            f"Model returned {scores.size} scores but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries; they must match the training "
            "class order exactly."
        )

    class_index = int(np.argmax(scores))
    predicted_class = CLASS_NAMES[class_index]
    confidence = float(scores[class_index])
    risk_level = RISK_LEVELS[predicted_class]
    emoji = RISK_EMOJIS[risk_level]

    # Plain Python floats so Gradio can serialise the mapping.
    confidence_scores = {
        class_name: float(score)
        for class_name, score in zip(CLASS_NAMES, scores)
    }

    # "\U0001F697" is U+1F697 AUTOMOBILE, escaped so the source stays ASCII.
    result = (
        "\n"
        "\U0001F697 DRIVER STATE ANALYSIS\n"
        "Prediction:\n"
        f"{predicted_class.upper()}\n"
        "Confidence:\n"
        f"{confidence:.2f}\n"
        "Risk Level:\n"
        f"{emoji} {risk_level}\n"
    )
    return result, confidence_scores
# =========================================================
# TITLE & DESCRIPTION
# =========================================================
# Emoji are written as escape sequences so the source stays ASCII and a
# wrong-encoding round trip cannot garble them. Gradio renders
# `description` as Markdown, so each feature line is a list item;
# bare consecutive lines would be merged into one paragraph.
title = "\U0001F697 AI Driver Safety Detection System"  # U+1F697 AUTOMOBILE
description = """
Upload a driver image to analyze fatigue and attention state using Deep Learning.

## Supported Driver States

- \U0001F7E2 Alert
- \U0001F534 Sleepy
- \U0001F7E0 Slow Blink
- \U0001F7E1 Yawning

## AI Features

- \u2705 CNN-Based Driver State Classification
- \u2705 Fatigue Risk Analysis
- \u2705 Deep Learning Inference
- \u2705 Real-Time Prediction Engine
"""
# =========================================================
# GRADIO INTERFACE
# =========================================================
# Input: a single uploaded image, handed to the callback as a NumPy array.
image_input = gr.Image(type="numpy", label="Upload Driver Image")

# Outputs, in the same order as predict_driver_state's return tuple:
# the formatted summary text, then the per-class score mapping.
result_box = gr.Textbox(label="Prediction Result")
scores_label = gr.Label(label="Confidence Scores")

interface = gr.Interface(
    fn=predict_driver_state,
    inputs=image_input,
    outputs=[result_box, scores_label],
    title=title,
    description=description,
)

# =========================================================
# LAUNCH
# =========================================================
# Called unconditionally at module level, so importing this module also
# launches the app.
interface.launch()