smhs16's picture
Upload 9 files
ac01a7a verified
"""
app.py — Gradio demo for HF Spaces.
Loads the trained model (or a stub) and exposes a live prediction UI.
"""
import gradio as gr
import pandas as pd
import numpy as np
import pickle
import os
from pathlib import Path
# ── Load model (falls back to a stub if not yet trained) ──────────────────────
MODEL_PATH = Path("models") / "best_model.pkl"


def load_model():
    """Return the pickled model stored at ``MODEL_PATH``, or ``None``.

    ``None`` signals "stub mode": no model file exists yet, so callers are
    expected to fall back to a heuristic instead of a real prediction.

    NOTE: unpickling can execute arbitrary code, so ``MODEL_PATH`` must only
    ever point at a file produced by a trusted training run.
    """
    if not MODEL_PATH.exists():
        return None  # stub mode
    with MODEL_PATH.open("rb") as handle:
        return pickle.load(handle)


# Loaded once at import time and shared by every prediction request.
model = load_model()
# Airline code -> integer code; the integer is the code's position in the list.
# NOTE(review): presumably this must match the encoding used when the model
# was trained — confirm before reordering or inserting entries.
CARRIER_MAP = {
    code: index
    for index, code in enumerate(
        ["AA", "DL", "UA", "WN", "B6", "AS", "F9", "NK"]
    )
}

# Airport code -> integer code (small lookup for demo), numbered by position.
AIRPORT_STUB = {
    code: index
    for index, code in enumerate([
        "ATL", "LAX", "ORD", "DFW", "DEN",
        "JFK", "SFO", "SEA", "LAS", "MIA",
    ])
}
def predict_delay(
    dep_hour, dep_dayofweek, dep_month,
    carrier, origin, dest,
    crs_elapsed_time, distance,
    origin_delay_rate,
    is_weekend, is_peak_hour,
):
    """Return a formatted delay prediction for a single flight.

    A one-row feature frame is assembled from the arguments. When a model was
    loaded its ``predict_proba`` output is used; otherwise a noisy heuristic
    stands in so the demo still responds.

    Args:
        dep_hour, dep_dayofweek, dep_month: Schedule values, cast to ``int``
            for the feature frame.
        carrier: Airline code looked up in ``CARRIER_MAP`` (unknown -> 0).
        origin, dest: Airport codes looked up in ``AIRPORT_STUB``
            (unknown origin -> 0, unknown destination -> 1).
        crs_elapsed_time, distance, origin_delay_rate: Cast to ``float`` for
            the feature frame.
        is_weekend, is_peak_hour: The string ``"Yes"`` sets the flag; any
            other value leaves it unset.

    Returns:
        A multi-line string with the label, the delay probability, a
        confidence band and a 20-cell text bar.
    """
    carrier_enc = CARRIER_MAP.get(carrier, 0)
    origin_enc = AIRPORT_STUB.get(origin, 0)
    dest_enc = AIRPORT_STUB.get(dest, 1)

    weekend = is_weekend == "Yes"
    peak = is_peak_hour == "Yes"

    row = {
        "dep_hour": int(dep_hour),
        "dep_dayofweek": int(dep_dayofweek),
        "dep_month": int(dep_month),
        "carrier_enc": carrier_enc,
        "origin_enc": origin_enc,
        "dest_enc": dest_enc,
        "crs_elapsed_time": float(crs_elapsed_time),
        "distance": float(distance),
        "origin_delay_rate": float(origin_delay_rate),
        "is_weekend": int(weekend),
        "is_peak_hour": int(peak),
    }
    features = pd.DataFrame([row])

    if model is None:
        # Demo stub: simple heuristic. Terms are added in a fixed order and a
        # small Gaussian jitter is applied before clamping to [0, 1].
        score = origin_delay_rate * 0.5
        score = score + (0.15 if peak else 0)
        score = score + (0.1 if weekend else 0)
        score = score + (dep_month in [6, 7, 12]) * 0.1
        score = score + np.random.normal(0, 0.05)
        prob = max(0.0, min(1.0, score))
    else:
        # Column 1 of the first row is taken as the "delayed" probability.
        prob = float(model.predict_proba(features)[0, 1])

    if prob >= 0.5:
        label = "🔴 LIKELY DELAYED"
    else:
        label = "🟢 LIKELY ON TIME"

    # Confidence grows with the distance from the 0.5 decision threshold.
    margin = abs(prob - 0.5)
    if margin > 0.25:
        confidence = "HIGH"
    elif margin > 0.1:
        confidence = "MEDIUM"
    else:
        confidence = "LOW"

    filled = int(prob * 20)
    bar = "█" * filled + "░" * (20 - filled)

    return (
        f"{label}\n\n"
        f"Delay probability : {prob:.1%}\n"
        f"Confidence : {confidence}\n"
        f"[{bar}] {prob:.1%}"
    )
# ── UI ────────────────────────────────────────────────────────────────────────
# Builds the Gradio Blocks app bound to `demo`: input widgets, a result box,
# a button wired to predict_delay, and a table of example inputs.
# NOTE(review): this copy of the file has lost its indentation, so the nesting
# of the widgets under gr.Row()/gr.Column() cannot be read off the text below —
# restore it from the original before running.
with gr.Blocks(
title="✈️ Flight Delay Predictor",
theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
) as demo:
# TODO: the GitHub link in the header text still contains the YOUR_USERNAME
# placeholder.
gr.Markdown("""
# ✈️ Flight Delay Prediction
Predict whether a flight will be **≥ 15 minutes late** using the trained XGBoost model.
> Part of the [Flight Delay ML Platform](https://github.com/YOUR_USERNAME/flight-delay-platform)
""")
with gr.Row():
with gr.Column():
gr.Markdown("### ✈️ Flight Details")
# Dropdown choices come straight from the encoding tables, so every
# selectable code has an entry in CARRIER_MAP / AIRPORT_STUB.
carrier = gr.Dropdown(list(CARRIER_MAP.keys()), value="AA", label="Airline")
origin = gr.Dropdown(list(AIRPORT_STUB.keys()), value="ATL", label="Origin Airport")
dest = gr.Dropdown(list(AIRPORT_STUB.keys()), value="LAX", label="Destination Airport")
distance = gr.Slider(100, 5000, value=1400, step=50, label="Distance (miles)")
crs_elapsed_time = gr.Slider(30, 600, value=185, step=5, label="Scheduled Duration (min)")
with gr.Column():
gr.Markdown("### 🕐 Schedule")
dep_hour = gr.Slider(0, 23, value=8, step=1, label="Departure Hour (0–23)")
dep_dayofweek = gr.Slider(0, 6, value=1, step=1, label="Day of Week (0=Mon, 6=Sun)")
dep_month = gr.Slider(1, 12, value=3, step=1, label="Month")
# The weekend and peak-hour flags are set by the user; nothing in this file
# derives them from the day-of-week or departure-hour sliders.
is_weekend = gr.Radio(["Yes", "No"], value="No", label="Weekend Flight?")
is_peak_hour = gr.Radio(["Yes", "No"], value="Yes", label="Peak Hour? (7–9am / 5–8pm)")
with gr.Column():
gr.Markdown("### 🌦 Airport History")
origin_delay_rate = gr.Slider(
0.0, 1.0, value=0.22, step=0.01,
label="Origin Airport 30-Day Delay Rate"
)
gr.Markdown("### 📊 Prediction")
# Read-only box that receives the string returned by predict_delay.
output = gr.Textbox(label="Result", lines=5, interactive=False)
predict_btn = gr.Button("Predict Delay →", variant="primary")
# The inputs are listed in the same order as predict_delay's parameters;
# keep the two in sync when adding or reordering fields.
predict_btn.click(
fn=predict_delay,
inputs=[
dep_hour, dep_dayofweek, dep_month,
carrier, origin, dest,
crs_elapsed_time, distance,
origin_delay_rate, is_weekend, is_peak_hour,
],
outputs=output,
)
# Each example row holds one value per component in `inputs`, in that order.
gr.Examples(
examples=[
[8, 1, 3, "AA", "ATL", "LAX", 185, 1400, 0.22, "No", "Yes"],
[18, 4, 7, "UA", "ORD", "JFK", 140, 780, 0.38, "No", "Yes"],
[6, 6, 1, "WN", "DEN", "SFO", 95, 950, 0.12, "Yes", "No" ],
[14, 3, 12,"DL", "ATL", "MIA", 75, 660, 0.45, "No", "No" ],
],
inputs=[
dep_hour, dep_dayofweek, dep_month,
carrier, origin, dest,
crs_elapsed_time, distance,
origin_delay_rate, is_weekend, is_peak_hour,
],
)
# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
demo.launch()