vansh0003 committed on
Commit
fa57b58
·
verified ·
1 Parent(s): c99d243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py CHANGED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import joblib
4
+ import numpy as np
5
+ from sklearn.impute import SimpleImputer
6
+
7
# -----------------------------
# Tuned classification model package
# -----------------------------
# The package is expected to be produced by the training script, e.g.:
# joblib.dump({"model": ensemble, "threshold": best_threshold, "columns": list(X_train.columns)}, "main/final_delay_model.pkl")
model_package = joblib.load("main/final_delay_model.pkl")
# Pull the three entries out of the package in one unpacking step.
ensemble_model, best_threshold, reference_columns = (
    model_package[key] for key in ("model", "threshold", "columns")
)

# -----------------------------
# Regression models and the column layout they were trained on
# -----------------------------
# Files are loaded in the order listed; each result is bound to the
# matching name on the left-hand side.
ridge_model, xgb_reg_model, gbr_reg_model, training_columns_reg = map(
    joblib.load,
    (
        "main/ridge_model.pkl",
        "main/xgb_model.pkl",
        "main/gbr_model.pkl",
        "main/training_columns.pkl",
    ),
)
24
+
25
+ # -----------------------------
26
+ # Preprocessing for classification
27
+ # -----------------------------
28
def preprocess_classification(df):
    """One-hot encode a classification request and align it to the training layout.

    Args:
        df: DataFrame with the raw request fields. It must contain every
            column listed in ``categorical_cols`` below.

    Returns:
        A new DataFrame built from the imputer output whose columns are
        exactly ``reference_columns``, in that order.
    """
    categorical_cols = ['UNIQUE_CARRIER', 'CARRIER', 'ORIGIN', 'DEST',
                        'ORIGIN_STATE_ABR', 'DEST_STATE_ABR',
                        'DEP_TIME_BLK', 'ARR_TIME_BLK']
    df_encoded = pd.get_dummies(df, columns=categorical_cols)

    # Align to the training layout in one step: columns missing from this
    # request are added as 0, columns not in reference_columns are dropped,
    # and the result follows the order of reference_columns. This mirrors
    # preprocess_regression and avoids inserting columns one at a time.
    df_encoded = df_encoded.reindex(columns=reference_columns, fill_value=0)

    # NOTE(review): the imputer is fitted on the rows passed in here, so its
    # medians come from the request itself rather than from training data.
    # Confirm this is intended, or load a fitted imputer with the model.
    imputer = SimpleImputer(strategy='median')
    df_encoded = pd.DataFrame(imputer.fit_transform(df_encoded), columns=df_encoded.columns)

    return df_encoded
47
+
48
+ # -----------------------------
49
+ # Preprocessing for regression
50
+ # -----------------------------
51
def preprocess_regression(df):
    """Encode a regression request and align it to ``training_columns_reg``.

    The ``time_of_day`` and ``wind_dir_bucket`` fields are one-hot encoded;
    the frame is then reindexed to the training columns, with any column
    absent from this request filled with 0.
    """
    dummy_sources = ['time_of_day', 'wind_dir_bucket']
    return pd.get_dummies(df, columns=dummy_sources).reindex(
        columns=training_columns_reg, fill_value=0
    )
55
+
56
+ # -----------------------------
57
+ # Delay category helper
58
+ # -----------------------------
59
def categorize_delay(minutes):
    """Map a delay in minutes to a human-readable severity label.

    Args:
        minutes: Delay in minutes. Values below 15 (including negative
            values) fall into the lowest band.

    Returns:
        One of the five label strings below; each band includes its lower
        bound and excludes its upper bound.

    Raises:
        ValueError: If ``minutes`` is NaN. NaN fails every ``<`` comparison
            and would otherwise be reported as the most severe category.
    """
    if np.isnan(minutes):
        raise ValueError("minutes must be a number, got NaN")

    # Each check only needs an upper bound: the earlier returns already
    # rule out everything below the current band.
    if minutes < 15:
        return "Delay not considered less than 15 mins"
    if minutes < 20:
        return "Delay is Minimum"
    if minutes < 30:
        return "Flight is moderately delayed"
    if minutes < 60:
        return "Flight is highly delayed"
    return "Flight is delayed too much"
70
+
71
+ # -----------------------------
72
+ # Classification prediction function
73
+ # -----------------------------
74
def predict_classification(YEAR, MONTH, DAY_OF_MONTH, DAY_OF_WEEK,
                           ORIGIN, DEST, CARRIER,
                           ORIGIN_STATE_ABR, DEST_STATE_ABR,
                           DEP_TIME_BLK, ARR_TIME_BLK,
                           temp, prcp, wspd, wdir, route_delay_rate):
    """Classify a flight as delayed or on time with the tuned ensemble.

    Args:
        YEAR, MONTH, DAY_OF_MONTH, DAY_OF_WEEK: Calendar fields; cast to int.
        ORIGIN, DEST, CARRIER, ORIGIN_STATE_ABR, DEST_STATE_ABR,
        DEP_TIME_BLK, ARR_TIME_BLK: Categorical fields passed through as
            given. ``CARRIER`` fills both the 'UNIQUE_CARRIER' and
            'CARRIER' columns.
        temp, prcp, wspd, wdir, route_delay_rate: Numeric fields; cast to
            float.

    Returns:
        dict with "Prediction" ("Delayed" when the positive-class
        probability is at or above the threshold, otherwise "On Time"),
        "Confidence" (that probability, rounded to 3 places) and
        "Threshold" (rounded to 3 places). Both numbers are builtin floats.

    Raises:
        TypeError / ValueError: From the ``int()`` / ``float()`` casts when
            a numeric field is missing or not numeric.
    """
    data = {
        'YEAR': int(YEAR),
        'MONTH': int(MONTH),
        'DAY_OF_MONTH': int(DAY_OF_MONTH),
        'DAY_OF_WEEK': int(DAY_OF_WEEK),
        'UNIQUE_CARRIER': CARRIER,
        'CARRIER': CARRIER,
        'ORIGIN': ORIGIN,
        'DEST': DEST,
        'ORIGIN_STATE_ABR': ORIGIN_STATE_ABR,
        'DEST_STATE_ABR': DEST_STATE_ABR,
        'DEP_TIME_BLK': DEP_TIME_BLK,
        'ARR_TIME_BLK': ARR_TIME_BLK,
        'temp': float(temp),
        'prcp': float(prcp),
        'wspd': float(wspd),
        'wdir': float(wdir),
        'route_delay_rate': float(route_delay_rate),
    }

    df_input = pd.DataFrame([data])
    X = preprocess_classification(df_input)

    # Convert to builtin floats before comparing and rounding: the model and
    # the stored threshold may hand back numpy scalars (e.g. float32), which
    # the standard JSON encoder cannot serialise.
    proba = float(ensemble_model.predict_proba(X)[0][1])
    threshold = float(best_threshold)

    return {
        "Prediction": "Delayed" if proba >= threshold else "On Time",
        "Confidence": round(proba, 3),
        "Threshold": round(threshold, 3),
    }
111
+
112
+ # -----------------------------
113
+ # Regression prediction function (unchanged)
114
+ # -----------------------------
115
def predict_regression_with_check(DEP_DELAY, DEP_DELAY_NEW, DEP_DEL15, DEP_DELAY_GROUP,
                                  temp, prcp, wspd, wdir, bad_weather, wind_dir_bucket,
                                  time_of_day, is_weekend):
    """Predict the delay in minutes with three regressors when DEP_DEL15 is set.

    When ``int(DEP_DEL15) == 0`` no model is run and a "No delay predicted"
    status is returned. Otherwise the request is encoded, the ridge,
    XGBoost and gradient-boosting models are each evaluated, and the largest
    of the three predictions selects the delay category.

    Returns:
        Either ``{"Status": "No delay predicted", "Delay Category": None}``
        or a dict with the three individual predictions, their maximum
        (all rounded to 2 places, as builtin floats) and "Delay Category".

    Raises:
        TypeError / ValueError: From the ``int()`` / ``float()`` casts when
            a numeric field is missing or not numeric.
    """
    if int(DEP_DEL15) == 0:
        return {
            "Status": "No delay predicted",
            "Delay Category": None,
        }

    data = {
        'DEP_DELAY': float(DEP_DELAY),
        'DEP_DELAY_NEW': float(DEP_DELAY_NEW),
        'DEP_DEL15': int(DEP_DEL15),
        'DEP_DELAY_GROUP': int(DEP_DELAY_GROUP),
        'temp': float(temp),
        'prcp': float(prcp),
        'wspd': float(wspd),
        'wdir': float(wdir),
        'bad_weather': int(bad_weather),
        'wind_dir_bucket': wind_dir_bucket,
        'time_of_day': time_of_day,
        'is_weekend': int(is_weekend),
    }
    df_input = pd.DataFrame([data])
    X = preprocess_regression(df_input)

    # Cast each prediction to a builtin float: the models may return numpy
    # scalars of different widths (e.g. float32), which the standard JSON
    # encoder cannot serialise and which round() would leave as numpy types.
    pred_ridge = float(ridge_model.predict(X)[0])
    pred_xgb = float(xgb_reg_model.predict(X)[0])
    pred_gbr = float(gbr_reg_model.predict(X)[0])

    # The category is driven by the largest of the three estimates.
    max_pred = max(pred_ridge, pred_xgb, pred_gbr)
    category = categorize_delay(max_pred)

    return {
        "Ridge Prediction": round(pred_ridge, 2),
        "XGBoost Prediction": round(pred_xgb, 2),
        "Gradient Boosting Prediction": round(pred_gbr, 2),
        "Max Prediction": round(max_pred, 2),
        "Delay Category": category,
    }
155
+
156
+ # -----------------------------
157
+ # Gradio Interface
158
+ # -----------------------------
159
# Input widgets for the classification tab, in the positional order of
# predict_classification's parameters.
classification_inputs = [
    *(gr.Number(label=caption) for caption in (
        "YEAR",
        "MONTH",
        "DAY_OF_MONTH",
        "DAY_OF_WEEK (1=Mon ... 7=Sun)",
    )),
    *(gr.Textbox(label=caption) for caption in (
        "Origin Airport Code",
        "Destination Airport Code",
        "Carrier Code",
        "Origin State Abbreviation",
        "Destination State Abbreviation",
        "Departure Time Block (e.g., 0600-0659)",
        "Arrival Time Block (e.g., 0900-0959)",
    )),
    *(gr.Number(label=caption) for caption in (
        "Temperature",
        "Precipitation",
        "Wind Speed",
        "Wind Direction",
        "Route Delay Rate (historical)",
    )),
]

# Input widgets for the regression tab, in the positional order of
# predict_regression_with_check's parameters.
regression_inputs = [
    *(gr.Number(label=caption) for caption in (
        "DEP_DELAY",
        "DEP_DELAY_NEW",
        "DEP_DEL15 (0 or 1)",
        "DEP_DELAY_GROUP",
        "Temperature",
        "Precipitation",
        "Wind Speed",
        "Wind Direction",
        "Bad Weather (0 or 1)",
    )),
    *(gr.Textbox(label=caption) for caption in (
        "Wind Dir Bucket (North/South/East/West/etc.)",
        "Time of Day (Morning/Afternoon/Evening/Night)",
    )),
    gr.Number(label="Is Weekend (0 or 1)"),
]

# One Interface per model family; both emit their result dict as JSON.
classification_tab = gr.Interface(
    title="Flight Delay Classification (Tuned Ensemble)",
    description="Predict delay classification using the tuned ensemble model with threshold optimization.",
    fn=predict_classification,
    inputs=classification_inputs,
    outputs="json",
)

regression_tab = gr.Interface(
    title="Flight Delay Regression (Conditional)",
    description="Predict arrival delay in minutes only if DEP_DEL15=1, with categorized output.",
    fn=predict_regression_with_check,
    inputs=regression_inputs,
    outputs="json",
)

# Combine both interfaces into a single tabbed app.
demo = gr.TabbedInterface(
    [classification_tab, regression_tab],
    ["Classification", "Regression"],
)

if __name__ == "__main__":
    demo.launch()