jialitan23 commited on
Commit
4ce6c93
·
verified ·
1 Parent(s): ccdb024

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -89
app.py CHANGED
@@ -2,101 +2,152 @@ import gradio as gr
2
  import joblib
3
  import pandas as pd
4
  import os
 
5
  import traceback
6
 
7
- # --- Load model and artifacts ---
 
 
 
 
 
 
 
 
 
8
  try:
 
9
  app_dir = os.path.dirname(os.path.abspath(__file__))
10
- model = joblib.load(os.path.join(app_dir, "fall_detection_model.joblib"))
11
- scaler = joblib.load(os.path.join(app_dir, "scaler.joblib"))
12
- encoder = joblib.load(os.path.join(app_dir, "encoder.joblib"))
13
- feature_names = joblib.load(os.path.join(app_dir, "feature_names.joblib")) # list of all features used during training
14
- print("Model, scaler, encoder, and feature names loaded successfully.")
15
- except FileNotFoundError as e:
16
- print(f"Error: Missing file: {e}")
17
- exit()
18
-
19
- # Categorical features and their categories (from encoder)
20
- categorical_features = ['Movement Activity', 'Location', 'day_of_week']
21
- categories_map = {cat: encoder.categories_[i].tolist() for i, cat in enumerate(categorical_features)}
22
 
23
- # --- Prediction function ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def predict_fall(movement_activity, location, day_of_week, hour_of_day, minute_of_day, time_since_last_event):
25
- try:
26
- # Initialize all features to 0
27
- data = {f: 0 for f in feature_names}
28
-
29
- # Set one-hot encoded categorical features
30
- data[f'Movement Activity_{movement_activity}'] = 1
31
- data[f'Location_{location}'] = 1
32
- data[f'day_of_week_{day_of_week}'] = 1
33
 
34
- # Set numeric features
35
- data['hour_of_day'] = hour_of_day
36
- data['minute_of_day'] = minute_of_day
37
- data['time_since_last_event'] = time_since_last_event
38
-
39
- # Create DataFrame with all features in correct order
40
- input_df = pd.DataFrame([data])[feature_names]
41
-
42
- # Get columns scaler was trained on
43
- scaler_feature_names = scaler.feature_names_in_
44
-
45
- # Scale only those columns, keep the rest unchanged
46
- scaled_array = scaler.transform(input_df[scaler_feature_names])
47
- input_df.loc[:, scaler_feature_names] = scaled_array
48
-
49
- # Predict using the model on fully prepared DataFrame
50
- pred_proba = model.predict_proba(input_df)[0, 1]
51
- threshold = 0.4
52
- pred_label = "Fall Detected" if pred_proba >= threshold else "No Fall"
53
-
54
- return f"Prediction: {pred_label}\nFall Probability: {pred_proba:.2f}"
 
 
 
 
55
 
56
  except Exception as e:
57
- tb = traceback.format_exc()
58
- print("Error in prediction:\n", tb)
59
- return f"Error: {str(e)}\nCheck server logs for details."
60
-
61
- # --- Build Gradio interface ---
62
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
63
- gr.Markdown("# Fall Detection Model")
64
- gr.Markdown("Provide sensor data to predict fall detection.")
65
-
66
- movement_activity_input = gr.Dropdown(
67
- label="Movement Activity",
68
- choices=categories_map['Movement Activity'],
69
- value=categories_map['Movement Activity'][0]
70
- )
71
- location_input = gr.Dropdown(
72
- label="Location",
73
- choices=categories_map['Location'],
74
- value=categories_map['Location'][0]
75
- )
76
- day_of_week_input = gr.Dropdown(
77
- label="Day of Week",
78
- choices=categories_map['day_of_week'],
79
- value=categories_map['day_of_week'][0]
80
- )
81
- hour_of_day_input = gr.Slider(0, 23, step=1, label="hour_of_day")
82
- minute_of_day_input = gr.Slider(0, 59, step=1, label="minute_of_day")
83
- time_since_last_event_input = gr.Number(label="time_since_last_event")
84
-
85
- prediction_output = gr.Textbox(label="Prediction Result", lines=5)
86
-
87
- predict_button = gr.Button("Run Prediction")
88
- predict_button.click(
89
- fn=predict_fall,
90
- inputs=[
91
- movement_activity_input,
92
- location_input,
93
- day_of_week_input,
94
- hour_of_day_input,
95
- minute_of_day_input,
96
- time_since_last_event_input
97
- ],
98
- outputs=prediction_output
99
- )
100
-
101
- if __name__ == "__main__":
102
- demo.launch()
 
 
 
 
 
 
 
 
 
 
2
  import joblib
3
  import pandas as pd
4
  import os
5
+ import numpy as np
6
  import traceback
7
 
8
+ # --- 1. Define Model and Preprocessing Components ---
9
+ model = None
10
+ scaler = None
11
+ encoder = None
12
+ all_feature_columns = None
13
+ numerical_features = None
14
+ categorical_features = None
15
+ prediction_threshold = 0.5
16
+
17
+ # --- 2. Load the Model and Preprocessing Tools with Error Handling ---
18
  try:
19
+ # Get the directory of the script to correctly locate model files
20
  app_dir = os.path.dirname(os.path.abspath(__file__))
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Load all the necessary model artifacts
23
+ model_path = os.path.join(app_dir, 'fall_detection_model.joblib')
24
+ scaler_path = os.path.join(app_dir, 'scaler.joblib')
25
+ encoder_path = os.path.join(app_dir, 'encoder.joblib')
26
+ feature_names_path = os.path.join(app_dir, 'feature_names.joblib')
27
+ threshold_path = os.path.join(app_dir, 'prediction_threshold.txt')
28
+
29
+ model = joblib.load(model_path)
30
+ scaler = joblib.load(scaler_path)
31
+ encoder = joblib.load(encoder_path)
32
+ all_feature_columns = joblib.load(feature_names_path)
33
+
34
+ # Read the prediction threshold from the text file
35
+ with open(threshold_path, 'r') as f:
36
+ prediction_threshold = float(f.read().strip())
37
+
38
+ print("All model and preprocessing files loaded successfully.")
39
+
40
+ # Dynamically separate numerical from categorical features using the encoder
41
+ # This is the most reliable and foolproof way to ensure column lists match.
42
+ # The get_feature_names_out() method from the encoder returns the names of the
43
+ # one-hot encoded columns. We can then infer the categorical and numerical
44
+ # features from the master list of all features.
45
+ encoded_feature_names = encoder.get_feature_names_out()
46
+ categorical_features_encoded = [col for col in all_feature_columns if col in encoded_feature_names]
47
+ numerical_features = [col for col in all_feature_columns if col not in encoded_feature_names]
48
+
49
+ # Map the original categorical features from the encoder's categories
50
+ categories_map = {
51
+ 'Movement Activity': encoder.categories_[0].tolist(),
52
+ 'Location': encoder.categories_[1].tolist(),
53
+ 'day_of_week': encoder.categories_[2].tolist()
54
+ }
55
+
56
+ except FileNotFoundError as e:
57
+ print(f"Error: A required file was not found. Please ensure all model artifacts "
58
+ f"are in the same directory as this script. Missing file: {e.filename}")
59
+ model = None # Set model to None to prevent app from launching
60
+ except Exception as e:
61
+ print(f"An unexpected error occurred during file loading: {e}\n{traceback.format_exc()}")
62
+ model = None
63
+
64
+ # --- 3. Prediction Function for Gradio ---
65
  def predict_fall(movement_activity, location, day_of_week, hour_of_day, minute_of_day, time_since_last_event):
66
+ """
67
+ Takes user inputs and uses the loaded model to predict the likelihood of a fall.
68
+ """
69
+ if model is None:
70
+ return "Error: Model not loaded. Check server logs for details."
 
 
 
71
 
72
+ try:
73
+ # Create a dictionary to hold the input data with all expected columns.
74
+ input_data = {col: 0.0 for col in all_feature_columns}
75
+
76
+ # Populate the one-hot encoded features with user input
77
+ input_data[f'Movement Activity_{movement_activity}'] = 1.0
78
+ input_data[f'Location_{location}'] = 1.0
79
+ input_data[f'day_of_week_{day_of_week}'] = 1.0
80
+
81
+ # Populate the numerical features with user input
82
+ input_data['hour_of_day'] = float(hour_of_day)
83
+ input_data['minute_of_day'] = float(minute_of_day)
84
+ input_data['time_since_last_event'] = float(time_since_last_event)
85
+
86
+ # Create a DataFrame from the dictionary, ensuring the column order is correct
87
+ input_df = pd.DataFrame([input_data], columns=all_feature_columns)
88
+
89
+ # Apply the scaler to only the numerical columns
90
+ input_df[numerical_features] = scaler.transform(input_df[numerical_features])
91
+
92
+ # Make the prediction
93
+ fall_probability = model.predict_proba(input_df)[:, 1][0]
94
+ prediction = 'Fall Detected' if fall_probability >= prediction_threshold else 'No Fall'
95
+
96
+ return f"Prediction: {prediction}\nFall Probability: {fall_probability:.2f}"
97
 
98
  except Exception as e:
99
+ return f"An error occurred during prediction: {e}\n{traceback.format_exc()}"
100
+
101
+ # --- 4. Gradio Interface Setup ---
102
+ if model is not None:
103
+ # Use the `gr.Blocks` for a more flexible layout.
104
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
105
+ gr.Markdown("# Fall Detection Model")
106
+ gr.Markdown(
107
+ f"This model predicts the likelihood of a fall based on user activity and location. "
108
+ f"A fall is detected if the probability is >= {prediction_threshold}."
109
+ )
110
+
111
+ with gr.Row():
112
+ movement_activity_input = gr.Dropdown(
113
+ label="Movement Activity",
114
+ choices=categories_map['Movement Activity'],
115
+ value=categories_map['Movement Activity'][0]
116
+ )
117
+ location_input = gr.Dropdown(
118
+ label="Location",
119
+ choices=categories_map['Location'],
120
+ value=categories_map['Location'][0]
121
+ )
122
+ day_of_week_input = gr.Dropdown(
123
+ label="Day of Week",
124
+ choices=categories_map['day_of_week'],
125
+ value=categories_map['day_of_week'][0]
126
+ )
127
+
128
+ with gr.Row():
129
+ hour_of_day_input = gr.Slider(0, 23, step=1, label="Hour of Day")
130
+ minute_of_day_input = gr.Slider(0, 59, step=1, label="Minute of Day")
131
+ time_since_last_event_input = gr.Number(label="Time Since Last Event (in seconds)")
132
+
133
+ predict_button = gr.Button("Run Prediction", variant="primary")
134
+ prediction_output = gr.Textbox(label="Prediction Result", lines=3)
135
+
136
+ predict_button.click(
137
+ fn=predict_fall,
138
+ inputs=[
139
+ movement_activity_input,
140
+ location_input,
141
+ day_of_week_input,
142
+ hour_of_day_input,
143
+ minute_of_day_input,
144
+ time_since_last_event_input
145
+ ],
146
+ outputs=prediction_output
147
+ )
148
+
149
+ # --- 5. Launch the Gradio App ---
150
+ if __name__ == "__main__":
151
+ demo.launch()
152
+ else:
153
+ print("Gradio app cannot be launched due to errors during model loading.")