# Flight / app.py
# Last commit by vansh0003: f58ebc3 "Update app.py" (verified)
from pathlib import Path

import gradio as gr
import joblib
import pandas as pd
# --- Ensemble weights: one per model name; these values sum to 1.0 ---
weights = {
    "GradientBoosting": 0.239,
    "RandomForest": 0.573,
    "XGBoost": 0.188,
}

# --- Model paths, relative to the directory that contains this script ---
# The model files are expected in a "main/" sub-folder next to app.py.
model_paths = {
    "GradientBoosting": "main/GradientBoosting_model.pkl",
    "RandomForest": "main/RandomForest_model.pkl",
    "XGBoost": "main/xgb_model.pkl",
}

# Resolve against this file's directory instead of the current working
# directory, so the app still finds its models when launched from elsewhere.
_APP_DIR = Path(__file__).resolve().parent

# predict() looks up weights[name] for every loaded model; fail fast at
# startup rather than with a KeyError on the first request.
_models_without_weight = sorted(set(model_paths) - set(weights))
if _models_without_weight:
    raise ValueError(
        f"No ensemble weight configured for model(s): {_models_without_weight}"
    )

# --- Load models ---
# NOTE: joblib.load unpickles arbitrary Python objects; only ever point these
# paths at trusted model files.
models = {
    name: joblib.load(_APP_DIR / path) for name, path in model_paths.items()
}
# --- Prediction function ---
def _categorize_delay(minutes):
    """Return the delay-category label for a predicted delay in minutes.

    Bands:
      * 20 <= minutes < 30   -> "Minimal Delay"
      * 30 <= minutes <= 60  -> "Moderate Delay"
      * minutes > 60         -> "Excessive Delay"
      * anything else        -> "No Significant Delay"
        (i.e. below 20, or NaN, which fails every comparison above)
    """
    if 20 <= minutes < 30:
        return "Minimal Delay"
    if 30 <= minutes <= 60:
        return "Moderate Delay"
    if minutes > 60:
        return "Excessive Delay"
    return "No Significant Delay"


def predict(
    dep_delay_x_congestion, dep_congestion, block_time_diff, enroute_delay,
    dep_delay_rolling_mean, dep_congestion_roll3h, route_hour_delay_mean,
    taxi_in_ratio, taxi_out_ratio, arr_congestion_roll3h, month_cos,
    pres_delta, wspd_delta, arr_congestion, month_sin, season_wind,
    season, distance_partofday, part_of_day_ord, segment_peak_hours,
    wind_speed_cat_ord, humidity_cat_ord, prcp_delta, pressure_cat_ord,
    distance_bin_ord
):
    """Run every loaded model on one feature row and blend the results.

    Each argument is a single feature value, one per Gradio input field, in
    positional order. They are packed into a one-row DataFrame whose column
    names and order are passed to the models unchanged. Presumably these
    must match the columns the models were trained on (TODO confirm against
    the training code).

    Returns:
        A 3-tuple ``(preds, final_pred, delay_category)``:
          * ``preds``: dict mapping each model name in ``models`` to that
            model's prediction, as a built-in ``float``;
          * ``final_pred``: ``sum(preds[name] * weights[name])`` over all
            models (labelled as minutes in the UI);
          * ``delay_category``: the label from ``_categorize_delay``.

    Raises:
        KeyError: if a loaded model has no entry in ``weights``.
    """
    # One-row frame; dict insertion order fixes the column order.
    X_input = pd.DataFrame([{
        'dep_delay_x_congestion': dep_delay_x_congestion,
        'dep_congestion': dep_congestion,
        'block_time_diff': block_time_diff,
        'enroute_delay': enroute_delay,
        'dep_delay_rolling_mean': dep_delay_rolling_mean,
        'dep_congestion_roll3h': dep_congestion_roll3h,
        'route_hour_delay_mean': route_hour_delay_mean,
        'taxi_in_ratio': taxi_in_ratio,
        'taxi_out_ratio': taxi_out_ratio,
        'arr_congestion_roll3h': arr_congestion_roll3h,
        'month_cos': month_cos,
        'pres_delta': pres_delta,
        'wspd_delta': wspd_delta,
        'arr_congestion': arr_congestion,
        'month_sin': month_sin,
        'season_wind': season_wind,
        'season': season,
        'distance_partofday': distance_partofday,
        'part_of_day_ord': part_of_day_ord,
        'segment_peak_hours': segment_peak_hours,
        'wind_speed_cat_ord': wind_speed_cat_ord,
        'humidity_cat_ord': humidity_cat_ord,
        'prcp_delta': prcp_delta,
        'pressure_cat_ord': pressure_cat_ord,
        'distance_bin_ord': distance_bin_ord,
    }])

    # Estimators return NumPy scalars. np.float32 (typical for XGBoost
    # regressors) is not a float subclass and is rejected by the stdlib json
    # encoder, so cast to built-in float before the values reach gr.JSON.
    # This also keeps the weighted sum in double precision.
    preds = {
        name: float(model.predict(X_input)[0])
        for name, model in models.items()
    }

    # Weighted ensemble prediction.
    final_pred = sum(value * weights[name] for name, value in preds.items())

    return preds, final_pred, _categorize_delay(final_pred)
# --- Gradio interface ---
# One numeric field per model feature. Gradio passes the field values to
# predict() positionally, so this order must mirror predict()'s parameters
# (the trailing comment on each label names the matching parameter).
inputs = [
    gr.Number(label=label)
    for label in (
        "Departure Delay × Congestion Index",      # dep_delay_x_congestion
        "Departure Congestion",                    # dep_congestion
        "Block Time Difference (min)",             # block_time_diff
        "En‑Route Delay (min)",                    # enroute_delay
        "Avg Departure Delay (last 3 flights)",    # dep_delay_rolling_mean
        "Departure Congestion (3‑hour rolling)",   # dep_congestion_roll3h
        "Avg Delay for Route & Hour",              # route_hour_delay_mean
        "Taxi‑In Time ÷ Block Time",               # taxi_in_ratio
        "Taxi‑Out Time ÷ Block Time",              # taxi_out_ratio
        "Arrival Congestion (3‑hour rolling)",     # arr_congestion_roll3h
        "Month (cosine encoding)",                 # month_cos
        "Pressure Change (Dest − Origin)",         # pres_delta
        "Wind Speed Change (Dest − Origin)",       # wspd_delta
        "Arrival Congestion",                      # arr_congestion
        "Month (sine encoding)",                   # month_sin
        "Seasonal Wind Category",                  # season_wind
        "Season (ordinal)",                        # season
        "Distance × Part of Day",                  # distance_partofday
        "Part of Day (ordinal)",                   # part_of_day_ord
        "Peak Hour Segment Flag",                  # segment_peak_hours
        "Wind Speed Category (ordinal)",           # wind_speed_cat_ord
        "Humidity Category (ordinal)",             # humidity_cat_ord
        "Precipitation Change (Dest − Origin)",    # prcp_delta
        "Pressure Category (ordinal)",             # pressure_cat_ord
        "Distance Bin (ordinal)",                  # distance_bin_ord
    )
]

# One output component per element of predict()'s returned 3-tuple.
outputs = [
    gr.JSON(label="Model Predictions"),
    gr.Number(label="Weighted Ensemble Prediction (minutes)"),
    gr.Textbox(label="Delay Category"),
]

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Flight Delay Prediction (Weighted Ensemble)",
    description=(
        "Enter flight features to get predictions from GradientBoosting, "
        "RandomForest, and XGBoost, plus a weighted ensemble result and "
        "delay category."
    ),
)

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()