Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import joblib | |
| import numpy as np | |
# -----------------------------
# Load models and training columns
# -----------------------------
# NOTE(review): joblib.load unpickles arbitrary objects; these artifacts are
# assumed to be trusted files bundled with the app, never user uploads.

# Classification models, plus the column list used to align one-hot features
# for them in preprocess_classification.
rf_model, xgb_clf_model, gbr_clf_model, training_columns_clf = (
    joblib.load(artifact_path)
    for artifact_path in (
        "main/random_forest_model.pkl",
        "main/xgboost_model.pkl",
        "main/gradient_boosting_model.pkl",
        "main/training_clm.pkl",
    )
)

# Regression models, plus the column list used to align one-hot features
# for them in preprocess_regression.
ridge_model, xgb_reg_model, gbr_reg_model, training_columns_reg = (
    joblib.load(artifact_path)
    for artifact_path in (
        "main/ridge_model.pkl",
        "main/xgb_model.pkl",
        "main/gbr_model.pkl",
        "main/training_columns.pkl",
    )
)
# -----------------------------
# Preprocessing functions
# -----------------------------
def preprocess_classification(df):
    """One-hot encode the flight-descriptor columns for the classifiers.

    The encoded frame is reindexed to ``training_columns_clf``: dummy columns
    never seen in training are dropped, and training columns absent from this
    input are added and filled with 0, so the layout matches what the
    classifiers expect.
    """
    one_hot_fields = [
        'ORIGIN', 'DEST', 'CARRIER', 'TAIL_NUM',
        'DEP_TIME_BLK', 'DEST_STATE_ABR', 'ORIGIN_CITY_NAME',
        'DEST_CITY_NAME', 'route',
    ]
    return pd.get_dummies(df, columns=one_hot_fields).reindex(
        columns=training_columns_clf, fill_value=0
    )
def preprocess_regression(df):
    """One-hot encode ``time_of_day`` and ``wind_dir_bucket`` for the regressors.

    The result is reindexed to ``training_columns_reg``: unknown dummy columns
    are dropped and missing training columns are filled with 0.
    """
    dummies = pd.get_dummies(df, columns=['time_of_day', 'wind_dir_bucket'])
    aligned = dummies.reindex(columns=training_columns_reg, fill_value=0)
    return aligned
# -----------------------------
# Delay category helper
# -----------------------------
def categorize_delay(minutes):
    """Map a delay in minutes to a human-readable severity label.

    Bands are half-open: below 15, [15, 20), [20, 30), [30, 60), and 60 or
    more. A value that is not below 60 falls through to the last label; this
    includes NaN, which fails every comparison.
    """
    bands = (
        (15, "Delay not considered less then 15mins"),
        (20, "Delay is Minimum"),
        (30, "Flight is moderately delayed"),
        (60, "Flight is highly delayed"),
    )
    for upper_bound, label in bands:
        if minutes < upper_bound:
            return label
    return "Flight is delayed too much"
# -----------------------------
# Prediction functions
# -----------------------------
def predict_classification(YEAR, MONTH, DAY_OF_MONTH, DAY_OF_WEEK,
                           ORIGIN, DEST, CARRIER, TAIL_NUM, DEP_TIME_BLK,
                           DEST_STATE_ABR, ORIGIN_CITY_NAME, DEST_CITY_NAME):
    """Score one flight with the three classifiers and combine them by vote.

    The form values, plus a derived ``route`` of ``"<ORIGIN>_<DEST>"``, are put
    into a single-row DataFrame. The row is encoded with
    ``preprocess_classification`` and passed to each model.

    Returns:
        dict: for each model, ``"<Model> Prediction"`` (int) and
        ``"<Model> Prob"``. The probability is column 1 of ``predict_proba``
        (presumably the "delayed" class), rounded to 3 decimals, or ``None``
        when the model has no ``predict_proba``. Also includes
        ``"Majority Vote"``, the rounded mean of the three predictions, which
        is the majority label when predictions are 0/1. All values are
        built-in Python types, so the dict is JSON-serializable.
    """

    def _clean(text):
        # Textbox input can carry stray leading/trailing whitespace. Left in,
        # the one-hot column name would not match a training column and the
        # category would silently encode as all zeros after reindexing.
        return "" if text is None else str(text).strip()

    origin = _clean(ORIGIN)
    dest = _clean(DEST)
    data = {
        'YEAR': int(YEAR),
        'MONTH': int(MONTH),
        'DAY_OF_MONTH': int(DAY_OF_MONTH),
        'DAY_OF_WEEK': int(DAY_OF_WEEK),
        'ORIGIN': origin,
        'DEST': dest,
        'CARRIER': _clean(CARRIER),
        'TAIL_NUM': _clean(TAIL_NUM),
        'DEP_TIME_BLK': _clean(DEP_TIME_BLK),
        'DEST_STATE_ABR': _clean(DEST_STATE_ABR),
        'ORIGIN_CITY_NAME': _clean(ORIGIN_CITY_NAME),
        'DEST_CITY_NAME': _clean(DEST_CITY_NAME),
        # Route is derived from origin/destination rather than entered.
        'route': f"{origin}_{dest}",
    }
    X = preprocess_classification(pd.DataFrame([data]))

    models = (
        ("Random Forest", rf_model),
        ("XGBoost", xgb_clf_model),
        ("Gradient Boosting", gbr_clf_model),
    )
    result = {}
    votes = []
    for name, model in models:
        pred = int(model.predict(X)[0])
        prob = None
        if hasattr(model, "predict_proba"):
            # Cast to a built-in float: numpy scalars such as numpy.float32
            # are rejected by the stdlib json encoder.
            prob = round(float(model.predict_proba(X)[0][1]), 3)
        votes.append(pred)
        result[f"{name} Prediction"] = pred
        result[f"{name} Prob"] = prob

    # With 0/1 labels the mean of three votes is 0, 1/3, 2/3 or 1, so
    # rounding it yields the majority label.
    result["Majority Vote"] = int(np.round(np.mean(votes)))
    return result
def predict_regression_with_check(DEP_DELAY, DEP_DELAY_NEW, DEP_DEL15, DEP_DELAY_GROUP,
                                  temp, prcp, wspd, wdir, bad_weather, wind_dir_bucket,
                                  time_of_day, is_weekend):
    """Predict delay minutes with three regressors unless no delay is flagged.

    When ``int(DEP_DEL15) == 0`` no model is run. Otherwise the inputs are put
    into a single-row DataFrame, encoded with ``preprocess_regression`` and
    passed to the ridge, XGBoost and gradient-boosting regressors.

    Returns:
        dict: ``{"Status": "No delay predicted", "Delay Category": None}`` when
        ``DEP_DEL15`` is 0. Otherwise each model's prediction, their maximum
        (``"Max Prediction"``) and the ``categorize_delay`` label for that
        maximum. Numbers are built-in floats rounded to 2 decimals, so the
        dict is JSON-serializable.
    """
    # If not delayed, skip regression entirely.
    if int(DEP_DEL15) == 0:
        return {
            "Status": "No delay predicted",
            "Delay Category": None,
        }

    def _clean(text):
        # Stray whitespace from the textbox would create a dummy column that
        # is presumably absent from training_columns_reg, and so would be
        # dropped to all zeros.
        return "" if text is None else str(text).strip()

    data = {
        'DEP_DELAY': float(DEP_DELAY),
        'DEP_DELAY_NEW': float(DEP_DELAY_NEW),
        'DEP_DEL15': int(DEP_DEL15),
        'DEP_DELAY_GROUP': int(DEP_DELAY_GROUP),
        'temp': float(temp),
        'prcp': float(prcp),
        'wspd': float(wspd),
        'wdir': float(wdir),
        'bad_weather': int(bad_weather),
        'wind_dir_bucket': _clean(wind_dir_bucket),
        'time_of_day': _clean(time_of_day),
        'is_weekend': int(is_weekend),
    }
    X = preprocess_regression(pd.DataFrame([data]))

    # Cast to built-in float: model outputs can be numpy scalars (e.g.
    # numpy.float32), which the stdlib json encoder rejects.
    pred_ridge = float(ridge_model.predict(X)[0])
    pred_xgb = float(xgb_reg_model.predict(X)[0])
    pred_gbr = float(gbr_reg_model.predict(X)[0])

    # The most pessimistic model drives the reported category.
    max_pred = max(pred_ridge, pred_xgb, pred_gbr)
    return {
        "Ridge Prediction": round(pred_ridge, 2),
        "XGBoost Prediction": round(pred_xgb, 2),
        "Gradient Boosting Prediction": round(pred_gbr, 2),
        "Max Prediction": round(max_pred, 2),
        "Delay Category": categorize_delay(max_pred),
    }
# -----------------------------
# Gradio Interface
# -----------------------------
# Each spec is (component factory, label). Gradio passes input values to the
# prediction function positionally, so the order here must match that
# function's parameter order.
_CLASSIFICATION_FIELDS = (
    (gr.Number, "YEAR"),
    (gr.Number, "MONTH"),
    (gr.Number, "DAY_OF_MONTH"),
    (gr.Number, "DAY_OF_WEEK (1=Mon ... 7=Sun)"),
    (gr.Textbox, "Origin Airport Code"),
    (gr.Textbox, "Destination Airport Code"),
    (gr.Textbox, "Carrier Code"),
    (gr.Textbox, "Tail Number"),
    (gr.Textbox, "Departure Time Block (e.g., 0600-0659)"),
    (gr.Textbox, "Destination State Abbreviation"),
    (gr.Textbox, "Origin City Name"),
    (gr.Textbox, "Destination City Name"),
)
_REGRESSION_FIELDS = (
    (gr.Number, "DEP_DELAY"),
    (gr.Number, "DEP_DELAY_NEW"),
    (gr.Number, "DEP_DEL15 (0 or 1)"),
    (gr.Number, "DEP_DELAY_GROUP"),
    (gr.Number, "Temperature"),
    (gr.Number, "Precipitation"),
    (gr.Number, "Wind Speed"),
    (gr.Number, "Wind Direction"),
    (gr.Number, "Bad Weather (0 or 1)"),
    (gr.Textbox, "Wind Dir Bucket (North/South/East/West/etc.)"),
    (gr.Textbox, "Time of Day (Morning/Afternoon/Evening/Night)"),
    (gr.Number, "Is Weekend (0 or 1)"),
)

classification_inputs = [make(label=text) for make, text in _CLASSIFICATION_FIELDS]
regression_inputs = [make(label=text) for make, text in _REGRESSION_FIELDS]
def _json_tab(fn, inputs, title, description):
    """Wrap a prediction function in a Gradio Interface that renders JSON."""
    return gr.Interface(
        fn=fn,
        inputs=inputs,
        outputs="json",
        title=title,
        description=description,
    )


classification_tab = _json_tab(
    predict_classification,
    classification_inputs,
    "Flight Delay Classification",
    "Predict delay classification using Random Forest, XGBoost, and Gradient Boosting.",
)
regression_tab = _json_tab(
    predict_regression_with_check,
    regression_inputs,
    "Flight Delay Regression (Conditional)",
    "Predict arrival delay in minutes only if DEP_DEL15=1, with categorized output.",
)

# Both interfaces are exposed as tabs of one app.
demo = gr.TabbedInterface(
    [classification_tab, regression_tab],
    tab_names=["Classification", "Regression"],
)

if __name__ == "__main__":
    demo.launch()