File size: 5,349 Bytes
8210755
 
 
 
c98bf6f
8210755
 
c98bf6f
8210755
 
 
 
 
49e6cce
8210755
 
 
 
 
 
 
 
 
 
 
49e6cce
8210755
c98bf6f
 
 
 
8210755
 
 
c98bf6f
49e6cce
c98bf6f
 
 
 
 
 
 
 
49e6cce
c98bf6f
 
49e6cce
8210755
 
 
49e6cce
c98bf6f
49e6cce
8210755
 
49e6cce
c98bf6f
 
 
 
 
49e6cce
8210755
 
49e6cce
8210755
 
49e6cce
8210755
 
 
49e6cce
c98bf6f
 
8210755
49e6cce
c98bf6f
 
 
 
 
 
 
 
 
 
 
 
 
8210755
 
c98bf6f
 
 
 
 
 
 
8210755
 
 
c98bf6f
8210755
 
c98bf6f
8210755
c98bf6f
 
8210755
c98bf6f
8210755
 
c98bf6f
 
 
8210755
c98bf6f
 
8210755
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
from datetime import datetime

# --- 1. Download and Load Models ---
print("Downloading models from Hugging Face Hub...")
REPO_ID = "matanzig/flight-price-prediction"


def _fetch_model(filename):
    """Download *filename* from the ``REPO_ID`` repo and load it with joblib.

    NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    code; this trusts whatever is published under REPO_ID.
    """
    local_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
    return joblib.load(local_path)


# Regression model: predicts the fare; classifier: predicts the price tier.
reg_model = _fetch_model("flight_price_rf_model.pkl")
cls_model = _fetch_model("flight_price_classifier_rf.pkl")


# Feature names for the single-row DataFrame handed to both models.
# Order matters: each position lines up with the matching value in the row
# assembled by predict_flight_price.
# NOTE(review): cluster_group_1..3 appear twice. pandas accepts duplicate
# column labels, and the models were presumably fitted on this exact list, so
# do not de-duplicate without confirming against the training pipeline.
COLUMNS = [
    # Route and fare attributes
    "startingAirport",
    "destinationAirport",
    "isBasicEconomy",
    "isRefundable",
    "isNonStop",
    "seatsRemaining",
    "totalTravelDistance",
    "month",
    "days_until_flight",
    "travelDuration_mins",
    "primary_airline",
    "primary_cabin",
    "departure_hour",
    # Day-of-week indicators (this list has no Friday column)
    "day_of_week_Monday",
    "day_of_week_Saturday",
    "day_of_week_Sunday",
    "day_of_week_Thursday",
    "day_of_week_Tuesday",
    "day_of_week_Wednesday",
    # Derived features
    "is_weekend",
    "days_until_flight_squared",
    # Cluster indicators (first occurrence)
    "cluster_group_1",
    "cluster_group_2",
    "cluster_group_3",
    # Cluster indicators (repeated labels)
    "cluster_group_1",
    "cluster_group_2",
    "cluster_group_3",
]


# Airline display name -> integer code fed to the models as `primary_airline`.
# Codes are the position in this alphabetical list; insertion order is kept,
# so the UI dropdown lists airlines in this same order.
AIRLINE_MAPPING = {
    name: code
    for code, name in enumerate([
        "Alaska Airlines",
        "American Airlines",
        "Boutique Air",
        "Cape Air",
        "Contour Airlines",
        "Delta",
        "Frontier Airlines",
        "Hawaiian Airlines",
        "JetBlue Airways",
        "Key Lime Air",
        "Southern Airways Express",
        "Spirit Airlines",
        "Sun Country Airlines",
        "United",
    ])
}

# --- 2. Prediction Engine ---
# Classifier output (class index) -> label shown in the "Expected Market Tier" box.
_TIER_LABELS = {0: "Budget Expected 🟢", 1: "Standard Expected 🟡", 2: "Premium Expected 🔴"}

_DATE_ERROR_MSG = "❌ Invalid Date Format! Please use exactly YYYY-MM-DD (e.g., 2026-07-15)."


def _parse_flight_date(flight_date):
    """Parse a ``YYYY-MM-DD`` string and return ``(month, weekday_name)``.

    Raises:
        gr.Error: if the value is missing, empty, or not in ``YYYY-MM-DD`` form.
    """
    text = "" if flight_date is None else str(flight_date).strip()
    try:
        dt = pd.to_datetime(text, format="%Y-%m-%d")
    except ValueError:
        raise gr.Error(_DATE_ERROR_MSG) from None
    # An empty string parses to NaT instead of raising; NaT.month is NaN, which
    # would otherwise be passed straight into the models.
    if pd.isna(dt):
        raise gr.Error(_DATE_ERROR_MSG)
    return dt.month, dt.day_name()


def _deal_analysis(seen_price, predicted_price, threshold=25):
    """Compare a user-observed price with the model's prediction.

    Args:
        seen_price: price the user found, or ``None``/``<= 0`` when not provided.
        predicted_price: fare predicted by the regression model.
        threshold: dollar band around the prediction treated as "fair".

    Returns:
        str: a human-readable verdict.
    """
    # The Gradio number box may hand over None when it is cleared; treat that
    # the same as "no price entered" instead of crashing on the comparison.
    if seen_price is None or not seen_price > 0:
        return "Enter a 'Seen Price' above to get an instant deal analysis."
    diff = seen_price - predicted_price
    if diff < -threshold:
        return f"🔥 Amazing Deal! This ticket is ${abs(diff):.0f} CHEAPER than the AI average. Book it now!"
    if diff > threshold:
        return f"⚠️ Overpriced! This ticket is ${diff:.0f} MORE EXPENSIVE than it should be. Wait or find another flight."
    return "⚖️ Fair Market Price. The price you found perfectly matches our algorithm's baseline."


def predict_flight_price(flight_date, distance, duration, days_until, seats, airline_name, seen_price, is_nonstop, is_basic_economy):
    """Predict a fair fare, a price tier, and a deal verdict for one flight.

    Args:
        flight_date: departure date as a ``YYYY-MM-DD`` string.
        distance: travel distance in miles.
        duration: travel duration in minutes.
        days_until: days between booking and departure.
        seats: seats remaining.
        airline_name: a key of ``AIRLINE_MAPPING`` (unknown names fall back to code 5).
        seen_price: price the user found online; ``None``/``0`` means "not provided".
        is_nonstop: whether the flight is direct.
        is_basic_economy: whether the fare is basic economy.

    Returns:
        tuple[str, str, str]: formatted predicted price (``"$<amount>"`` with two
        decimals), the classifier's tier label (``"Unknown"`` for an unexpected
        class), and the deal-analysis message.

    Raises:
        gr.Error: if ``flight_date`` is missing or not in ``YYYY-MM-DD`` form.
    """
    month_val, day_name = _parse_flight_date(flight_date)
    is_weekend = 1 if day_name in ('Saturday', 'Sunday') else 0

    # Hand-coded feature scaling; presumably mirrors the standardization used at
    # training time. NOTE(review): confirm against the training pipeline.
    scaled_month = (month_val - 7.0) / 2.0
    scaled_days = (days_until - 30) / 15.0
    scaled_days_sq = scaled_days ** 2

    # Unknown airline names fall back to code 5 ("Delta" in AIRLINE_MAPPING).
    airline_id = AIRLINE_MAPPING.get(airline_name, 5)

    # One value per entry of COLUMNS, in the same order. Features the UI does
    # not collect are pinned to fixed values.
    row_data = [
        4, 5,  # startingAirport, destinationAirport: fixed codes. NOTE(review): which airports?
        int(is_basic_economy),
        0,  # isRefundable: always "no"
        int(is_nonstop),
        int(seats),
        float(distance),
        float(scaled_month),
        float(scaled_days),
        float(duration),  # passed unscaled, unlike month/days
        int(airline_id),
        1,  # primary_cabin: fixed code
        -0.32,  # departure_hour: fixed; looks pre-scaled. TODO confirm
        day_name == 'Monday', day_name == 'Saturday', day_name == 'Sunday',
        day_name == 'Thursday', day_name == 'Tuesday', day_name == 'Wednesday',
        is_weekend,
        float(scaled_days_sq),
        # The six cluster_group_* indicator columns, all left unset.
        False, False, False, False, False, False,
    ]

    input_df = pd.DataFrame([row_data], columns=COLUMNS)

    reg_prediction = reg_model.predict(input_df)[0]
    cls_prediction = cls_model.predict(input_df)[0]

    expected_tier = _TIER_LABELS.get(cls_prediction, "Unknown")
    deal_analysis = _deal_analysis(seen_price, reg_prediction)

    return f"${reg_prediction:.2f}", expected_tier, deal_analysis


# --- 3. Gradio Interface (UI) ---
# Markdown text shown under the app title (passed as `description=` to
# gr.Interface below). It tells users about the training-data window
# (April-October 2022) so they understand the model's limits.
DESCRIPTION = """
**Welcome to the US Flight Price Predictor AI!** 
Enter your flight details below to get an AI-powered baseline fare and deal analysis.

⚠️ **Model Limitations:** This algorithm was trained exclusively on Expedia data spanning from **April to October 2022**. It is highly optimized for summer and early fall travel dynamics, but does *not* account for major winter holiday price surges (e.g., Thanksgiving, Christmas) or extreme macroeconomic inflation events beyond that window.
"""

# Input widgets, in the same order as predict_flight_price's parameters.
_input_components = [
    gr.Textbox(value="2026-07-15", label="Flight Date (YYYY-MM-DD)", placeholder="e.g., 2026-08-01"),
    gr.Slider(minimum=100, maximum=3000, step=50, value=1300, label="Travel Distance (Miles)"),
    gr.Slider(minimum=60, maximum=800, step=10, value=400, label="Travel Duration (Minutes)"),
    gr.Slider(minimum=0, maximum=90, step=1, value=45, label="How many days in advance are you booking?"),
    gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Seats Remaining"),
    gr.Dropdown(choices=list(AIRLINE_MAPPING), value="Delta", label="Airline"),
    gr.Number(value=0, label="Price you found online ($) - Optional"),
    gr.Checkbox(label="Is Non-Stop (Direct Flight)?"),
    gr.Checkbox(label="Is Basic Economy?"),
]

# Output boxes, one per element of the 3-tuple returned by predict_flight_price.
_output_components = [
    gr.Textbox(label="🤖 AI Predicted Fair Price"),
    gr.Textbox(label="📊 Expected Market Tier (Classifier)"),
    gr.Textbox(label="💡 Personal Deal Analysis"),
]

interface = gr.Interface(
    fn=predict_flight_price,
    inputs=_input_components,
    outputs=_output_components,
    title="✈️ Smart Flight Price Predictor",
    description=DESCRIPTION,
    theme="default",
)

# Serve the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()