asrcoddeploy commited on
Commit
fa59a09
ยท
verified ยท
1 Parent(s): 3360661

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -34
app.py CHANGED
@@ -17,7 +17,9 @@ from tensorflow.keras.preprocessing.image import img_to_array
17
  model = load_model("final_driver_state_model.h5")
18
 
19
  # =========================================================
20
- # CLASS NAMES
 
 
21
  # =========================================================
22
 
23
  CLASS_NAMES = [
@@ -39,7 +41,7 @@ RISK_LEVELS = {
39
  }
40
 
41
  # =========================================================
42
- # COLORS
43
  # =========================================================
44
 
45
  RISK_EMOJIS = {
@@ -50,24 +52,17 @@ RISK_EMOJIS = {
50
  }
51
 
52
  # =========================================================
53
- # PREDICTION FUNCTION
 
 
54
  # =========================================================
55
 
56
- def predict_driver_state(image):
57
-
58
- # =====================================================
59
- # VALIDATION
60
- # =====================================================
61
 
62
- if image is None:
63
- return "Please upload an image.", None
64
-
65
- # =====================================================
66
- # PREPROCESSING
67
- # IMPORTANT:
68
- # Gradio already gives RGB image
69
- # DO NOT use cvtColor here
70
- # =====================================================
71
 
72
  image = cv2.resize(image, (224, 224))
73
 
@@ -77,13 +72,41 @@ def predict_driver_state(image):
77
 
78
  image = np.expand_dims(image, axis=0)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # =====================================================
81
  # PREDICTION
82
  # =====================================================
83
 
84
- prediction = model.predict(image, verbose=0)
 
 
 
85
 
86
- class_index = np.argmax(prediction)
 
 
 
 
87
 
88
  predicted_class = CLASS_NAMES[class_index]
89
 
@@ -125,13 +148,7 @@ Risk Level:
125
  return result, confidence_scores
126
 
127
  # =========================================================
128
- # EXAMPLES
129
- # =========================================================
130
-
131
- examples = []
132
-
133
- # =========================================================
134
- # UI
135
  # =========================================================
136
 
137
  title = "๐Ÿš— AI Driver Safety Detection System"
@@ -153,7 +170,7 @@ Upload a driver image to analyze fatigue and attention state using Deep Learning
153
  """
154
 
155
  # =========================================================
156
- # INTERFACE
157
  # =========================================================
158
 
159
  interface = gr.Interface(
@@ -178,13 +195,7 @@ interface = gr.Interface(
178
 
179
  title=title,
180
 
181
- description=description,
182
-
183
- examples=examples,
184
-
185
- theme="soft",
186
-
187
- allow_flagging="never"
188
  )
189
 
190
  # =========================================================
 
17
  model = load_model("final_driver_state_model.h5")
18
 
19
  # =========================================================
20
+ # CLASS LABELS
21
+ # IMPORTANT:
22
+ # Must match training class order exactly
23
  # =========================================================
24
 
25
  CLASS_NAMES = [
 
41
  }
42
 
43
  # =========================================================
44
+ # EMOJIS
45
  # =========================================================
46
 
47
  RISK_EMOJIS = {
 
52
  }
53
 
54
  # =========================================================
55
+ # IMAGE PREPROCESSING
56
+ # IMPORTANT:
57
+ # Match training preprocessing
58
  # =========================================================
59
 
60
+ def preprocess_image(image):
 
 
 
 
61
 
62
+ # -----------------------------------------------------
63
+ # Gradio already provides RGB image
64
+ # DO NOT use cvtColor
65
+ # -----------------------------------------------------
 
 
 
 
 
66
 
67
  image = cv2.resize(image, (224, 224))
68
 
 
72
 
73
  image = np.expand_dims(image, axis=0)
74
 
75
+ return image
76
+
77
+ # =========================================================
78
+ # PREDICTION FUNCTION
79
+ # =========================================================
80
+
81
+ def predict_driver_state(image):
82
+
83
+ if image is None:
84
+
85
+ return (
86
+ "Please upload an image.",
87
+ {}
88
+ )
89
+
90
+ # =====================================================
91
+ # PREPROCESS
92
+ # =====================================================
93
+
94
+ processed_image = preprocess_image(image)
95
+
96
  # =====================================================
97
  # PREDICTION
98
  # =====================================================
99
 
100
+ prediction = model.predict(
101
+ processed_image,
102
+ verbose=0
103
+ )
104
 
105
+ # =====================================================
106
+ # RESULTS
107
+ # =====================================================
108
+
109
+ class_index = int(np.argmax(prediction))
110
 
111
  predicted_class = CLASS_NAMES[class_index]
112
 
 
148
  return result, confidence_scores
149
 
150
  # =========================================================
151
+ # TITLE & DESCRIPTION
 
 
 
 
 
 
152
  # =========================================================
153
 
154
  title = "๐Ÿš— AI Driver Safety Detection System"
 
170
  """
171
 
172
  # =========================================================
173
+ # GRADIO INTERFACE
174
  # =========================================================
175
 
176
  interface = gr.Interface(
 
195
 
196
  title=title,
197
 
198
+ description=description
 
 
 
 
 
 
199
  )
200
 
201
  # =========================================================