mimo1972 commited on
Commit
a0e0a43
Β·
verified Β·
1 Parent(s): 136f7e9

Upload 4 files

Browse files
AirFlights_HistBoost_model.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3307968d2d458adfa9331b73d24b22e44b7a232847cf5cf5ff01245c8ec61524
3
- size 541810
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04058bf9b544483c06567ddd213884f85d552d6ef02c69e94f1bd8b6a820c580
3
+ size 834018
flightprice.py CHANGED
@@ -1,163 +1,230 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import joblib
5
- import matplotlib.pyplot as plt
6
- import seaborn as sns
7
- from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
8
-
9
- # Set Page Config
10
- st.set_page_config(page_title="Flight Price Predictor", layout="wide")
11
-
12
- # --- 1. Helper Functions ---
13
- @st.cache_data
14
- def load_data():
15
- """Loads the test data to get unique values for dropdowns and for evaluation."""
16
- x_test = pd.read_parquet('x_test.parquet')
17
- y_test = pd.read_parquet('y_test.parquet')
18
- return x_test, y_test
19
-
20
- @st.cache_resource
21
- def load_model():
22
- """Loads the trained HistGradientBoosting model."""
23
- return joblib.load('AirFlights_HistBoost_model.pkl')
24
-
25
- # Load Data and Model
26
- try:
27
- x_test, y_test = load_data()
28
- model = load_model()
29
- # Ensure target is 1D array
30
- if isinstance(y_test, pd.DataFrame):
31
- y_test_series = y_test.iloc[:, 0]
32
- else:
33
- y_test_series = y_test
34
- except Exception as e:
35
- st.error(f"Error loading files: {e}")
36
- st.stop()
37
-
38
- # --- 2. Sidebar Navigation ---
39
- st.sidebar.title("Navigation")
40
- page = st.sidebar.radio("Go to", ["✈️ Predict Price", "qh Model Evaluation"])
41
-
42
- # --- PAGE 1: PREDICT PRICE ---
43
- if page == "✈️ Predict Price":
44
- st.title("✈️ Flight Price Prediction")
45
- st.markdown("Enter the flight details below to get an estimated price.")
46
-
47
- # Create a form for user input
48
- with st.form("prediction_form"):
49
- col1, col2, col3 = st.columns(3)
50
-
51
- # We extract unique values from x_test to populate dropdowns automatically
52
- # This ensures the inputs match exactly what the model learned
53
-
54
- with col1:
55
- airline = st.selectbox("Airline", sorted(x_test['Airline'].unique()))
56
- source = st.selectbox("Source", sorted(x_test['Source'].unique()))
57
- destination = st.selectbox("Destination", sorted(x_test['Destination'].unique()))
58
-
59
- with col2:
60
- # Categorical Time Features
61
- month = st.selectbox("Month", x_test['Month'].unique())
62
- day = st.selectbox("Day", x_test['Day'].unique()) # e.g. Weekday or Day of Month
63
- dept_quarter = st.selectbox("Departure Time of Day", x_test['Dept_Day_Quarter'].unique())
64
-
65
- with col3:
66
- # Numerical Features
67
- stops = st.number_input("Total Stops", min_value=0, max_value=4, step=1, value=0)
68
- duration = st.number_input("Duration (minutes)", min_value=30, max_value=3000, step=15, value=120)
69
-
70
- submitted = st.form_submit_button("Predict Price")
71
-
72
- if submitted:
73
- # 1. Prepare Input Data
74
- input_data = pd.DataFrame({
75
- 'Airline': [airline],
76
- 'Source': [source],
77
- 'Destination': [destination],
78
- 'Total_Stops': [stops],
79
- 'Duration_minutes': [duration],
80
- 'Day': [day],
81
- 'Month': [month],
82
- 'Dept_Day_Quarter': [dept_quarter]
83
- })
84
-
85
- # Ensure columns are in the exact same order as x_test
86
- input_data = input_data[x_test.columns]
87
-
88
- # 2. Predict (Model returns Log Price)
89
- log_prediction = model.predict(input_data)[0]
90
-
91
- # 3. Inverse Transform (Log -> Real Price)
92
- real_price = np.expm1(log_prediction)
93
-
94
- # 4. Display Result
95
- st.success(f"Estimated Ticket Price: β‚Ή {real_price:,.2f}")
96
-
97
- # Debug info (optional)
98
- with st.expander("See processed input"):
99
- st.write(input_data)
100
-
101
-
102
- # --- PAGE 2: MODEL EVALUATION ---
103
- elif page == "qh Model Evaluation":
104
- st.title("qh Model Performance Report")
105
- st.write("Evaluating the model on `x_test.parquet` and `y_test.parquet`.")
106
-
107
- if st.button("Run Evaluation"):
108
- with st.spinner("Calculating predictions..."):
109
- # 1. Predict on Test Set
110
- y_pred_log = model.predict(x_test)
111
-
112
- # 2. Convert to Real Prices
113
- y_pred_real = np.expm1(y_pred_log)
114
- y_test_real = np.expm1(y_test_series)
115
-
116
- # 3. Metrics
117
- r2 = r2_score(y_test_series, y_pred_log) # R2 on Log scale (Model Metric)
118
- r2_real = r2_score(y_test_real, y_pred_real) # R2 on Real scale (Business Metric)
119
- mae = mean_absolute_error(y_test_real, y_pred_real)
120
- rmse = np.sqrt(mean_squared_error(y_test_real, y_pred_real))
121
-
122
- # --- Display Metrics ---
123
- col1, col2, col3, col4 = st.columns(4)
124
- col1.metric("R2 Score (Log)", f"{r2:.4f}")
125
- col2.metric("R2 Score (Real)", f"{r2_real:.4f}")
126
- col3.metric("MAE (Error)", f"β‚Ή {mae:.0f}")
127
- col4.metric("RMSE (Error)", f"β‚Ή {rmse:.0f}")
128
-
129
- st.markdown("---")
130
-
131
- # --- Graphs ---
132
- tab1, tab2 = st.tabs(["Actual vs Predicted", "Residuals Distribution"])
133
-
134
- with tab1:
135
- st.subheader("Actual Prices vs Predicted Prices")
136
- fig, ax = plt.subplots(figsize=(10, 6))
137
- sns.scatterplot(x=y_test_real, y=y_pred_real, alpha=0.5, color="blue", ax=ax)
138
- # Perfect prediction line
139
- min_val = min(y_test_real.min(), y_pred_real.min())
140
- max_val = max(y_test_real.max(), y_pred_real.max())
141
- ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label="Perfect Prediction")
142
- ax.set_xlabel("Actual Price")
143
- ax.set_ylabel("Predicted Price")
144
- ax.legend()
145
- st.pyplot(fig)
146
-
147
- with tab2:
148
- st.subheader("Residuals (Error) Distribution")
149
- residuals = y_test_real - y_pred_real
150
- fig, ax = plt.subplots(figsize=(10, 6))
151
- sns.histplot(residuals, kde=True, color="purple", ax=ax)
152
- ax.set_xlabel("Error (Actual - Predicted)")
153
- ax.set_title("Are the errors centered around 0?")
154
- st.pyplot(fig)
155
-
156
- # --- Data Table ---
157
- st.markdown("---")
158
- st.subheader("Detailed Test Data & Predictions")
159
- results_df = x_test.copy()
160
- results_df['Actual_Price'] = y_test_real
161
- results_df['Predicted_Price'] = y_pred_real
162
- results_df['Difference'] = results_df['Actual_Price'] - results_df['Predicted_Price']
163
- st.dataframe(results_df.head(100))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
8
+
9
+ # Set Page Config
10
+ st.set_page_config(page_title="Flight Price System", layout="wide", page_icon="✈️")
11
+
12
+ # --- 1. Helper Functions ---
13
+ @st.cache_data
14
+ def load_data():
15
+ try:
16
+ x_train = pd.read_parquet('x_train.parquet')
17
+ x_test = pd.read_parquet('x_test.parquet')
18
+ y_test = pd.read_parquet('y_test.parquet')
19
+
20
+ # CLEANING: Strip whitespace
21
+ for col in ['Airline', 'Source', 'Destination', 'Route']:
22
+ if col in x_train.columns:
23
+ x_train[col] = x_train[col].astype(str).str.strip()
24
+
25
+ return x_train, x_test, y_test
26
+ except Exception as e:
27
+ st.error(f"Error loading data files: {e}")
28
+ return None, None, None
29
+
30
+ @st.cache_resource
31
+ def load_model():
32
+ try:
33
+ return joblib.load('AirFlights_HistBoost_model.pkl')
34
+ except Exception as e:
35
+ st.error(f"Error loading model: {e}")
36
+ return None
37
+
38
+ def get_day_name(day_num, month_name):
39
+ try:
40
+ date_str = f"{int(day_num)}-{month_name}-2019"
41
+ return pd.to_datetime(date_str, format="%d-%B-%Y").day_name()
42
+ except:
43
+ return "Monday"
44
+
45
+ def get_day_quarter(hour):
46
+ if 5 <= hour < 12:
47
+ return 'Morning'
48
+ elif 12 <= hour < 17:
49
+ return 'Afternoon'
50
+ elif 17 <= hour < 21:
51
+ return 'Evening'
52
+ else:
53
+ return 'Night'
54
+
55
+ # --- 2. EXHAUSTIVE AIRPORT CODE MAPPING ---
56
+ CITY_TO_CODE = {
57
+ 'Banglore': 'BLR',
58
+ 'Bangalore': 'BLR',
59
+ 'Delhi': 'DEL',
60
+ 'New Delhi': 'DEL',
61
+ 'Kolkata': 'CCU',
62
+ 'Calcutta': 'CCU',
63
+ 'Hyderabad': 'HYD',
64
+ 'Chennai': 'MAA',
65
+ 'Madras': 'MAA',
66
+ 'Mumbai': 'BOM',
67
+ 'Bombay': 'BOM',
68
+ 'Cochin': 'COK',
69
+ 'Kochi': 'COK',
70
+ 'Pune': 'PNQ',
71
+ 'Goa': 'GOI',
72
+ 'Jaipur': 'JAI',
73
+ 'Lucknow': 'LKO',
74
+ 'Patna': 'PAT',
75
+ 'Varanasi': 'VNS',
76
+ 'Bhubaneswar': 'BBI',
77
+ 'Nagpur': 'NAG',
78
+ 'Trivandrum': 'TRV'
79
+ }
80
+
81
+ def get_code(city):
82
+ clean_city = city.strip()
83
+ return CITY_TO_CODE.get(clean_city, clean_city[:3].upper())
84
+
85
+ def is_route_valid(route_str, source_code, dest_code):
86
+ if pd.isna(route_str) or route_str == 'nan':
87
+ return False
88
+ route_upper = route_str.upper()
89
+ parts = route_upper.replace("β†’", " ").replace("->", " ").split()
90
+ if not parts:
91
+ return False
92
+ first_stop = parts[0]
93
+ return (first_stop == source_code) and (dest_code in parts)
94
+
95
+ # Load Data
96
+ x_train, x_test, y_test = load_data()
97
+ model = load_model()
98
+
99
+ if x_train is None or model is None:
100
+ st.stop()
101
+
102
+ if isinstance(y_test, pd.DataFrame):
103
+ y_test_series = y_test.iloc[:, 0]
104
+ else:
105
+ y_test_series = y_test
106
+
107
+ # Build Lookup
108
+ route_lookup = {}
109
+ if 'Route' in x_train.columns and 'Total_Stops' in x_train.columns:
110
+ temp = x_train[['Route', 'Total_Stops']].drop_duplicates(subset=['Route'])
111
+ route_lookup = temp.set_index('Route')['Total_Stops'].to_dict()
112
+
113
+ # --- 3. App Layout ---
114
+ st.sidebar.title("Navigation")
115
+ page = st.sidebar.radio("Go to", ["πŸ’° Price Prediction", "πŸ“Š Model Evaluation"])
116
+
117
+ if page == "πŸ’° Price Prediction":
118
+ st.title("✈️ Flight Price Prediction")
119
+ st.markdown("### Enter Flight Details")
120
+
121
+ # REMOVED st.form HERE so inputs update instantly!
122
+
123
+ c1, c2 = st.columns(2)
124
+ with c1:
125
+ st.subheader("Flight Info")
126
+
127
+ # Source (Updates instantly now)
128
+ source = st.selectbox("Source", sorted(x_train['Source'].unique()))
129
+ src_code = get_code(source)
130
+ st.success(f"πŸ›« Source Code: **{src_code}**")
131
+
132
+ # Destination (Updates instantly now)
133
+ destination = st.selectbox("Destination", sorted(x_train['Destination'].unique()))
134
+ dest_code = get_code(destination)
135
+ st.error(f"πŸ›¬ Destination Code: **{dest_code}**")
136
+
137
+ airline = st.selectbox("Airline", sorted(x_train['Airline'].unique()))
138
+
139
+ with c2:
140
+ st.subheader("Date & Time")
141
+ if 'Month' in x_train.columns:
142
+ months = sorted(x_train['Month'].unique())
143
+ else:
144
+ months = ['March', 'April', 'May', 'June', 'September', 'December']
145
+ month = st.selectbox("Month", months)
146
+ day_number = st.number_input("Day Number", 1, 31, 1)
147
+ dept_hour = st.number_input("Departure Hour", 0, 23, 10)
148
+
149
+ st.markdown("---")
150
+ st.subheader("Route Selection")
151
+
152
+ selected_route = None
153
+ stops_val = 0
154
+
155
+ if 'Route' in x_train.columns:
156
+ # 1. Get all routes
157
+ all_routes = sorted(x_train['Route'].unique().astype(str))
158
+
159
+ # 2. FILTER: Show only routes starting with the exact Source Code
160
+ # Now this runs immediately when you change 'Source' above
161
+ filtered_routes = []
162
+ for r in all_routes:
163
+ parts = r.upper().replace("β†’", " ").replace("->", " ").split()
164
+ if parts and parts[0] == src_code:
165
+ filtered_routes.append(r)
166
+
167
+ if filtered_routes:
168
+ selected_route_raw = st.selectbox("Select Route", filtered_routes)
169
+
170
+ # 3. VALIDATE
171
+ if is_route_valid(selected_route_raw, src_code, dest_code):
172
+ selected_route = selected_route_raw
173
+ stops_val = route_lookup.get(selected_route, 0)
174
+ st.metric("Total Stops", stops_val)
175
+ st.success("βœ… Valid Route")
176
+ else:
177
+ st.warning(f"⚠️ **{selected_route_raw}** starts at **{src_code}** but does not reach **{dest_code}**. Please check your destination.")
178
+ selected_route = None
179
+ else:
180
+ st.error(f"No routes found starting with code **{src_code}**.")
181
+ else:
182
+ st.error("Route column missing.")
183
+
184
+ # We only use a button for the final prediction calculation
185
+ if st.button("Predict Price", type="primary"):
186
+ if selected_route:
187
+ # Prepare Input
188
+ day_name = get_day_name(day_number, month)
189
+ quarter = get_day_quarter(dept_hour)
190
+
191
+ input_df = pd.DataFrame({
192
+ 'Airline': [airline], 'Source': [source], 'Destination': [destination],
193
+ 'Month': [month], 'Route': [selected_route],
194
+ 'Day_number': [day_number], 'Dept_hour': [dept_hour],
195
+ 'Day': [day_name], 'Dept_Day_Quarter': [quarter],
196
+ 'Total_Stops': [stops_val]
197
+ })
198
+
199
+ # Align Cols
200
+ final_input = pd.DataFrame(columns=x_train.columns)
201
+ for col in x_train.columns:
202
+ final_input.loc[0, col] = input_df.iloc[0].get(col, 0)
203
+
204
+ # Types
205
+ for col in final_input.columns:
206
+ if x_train[col].dtype == 'object':
207
+ final_input[col] = final_input[col].astype(str)
208
+ else:
209
+ final_input[col] = pd.to_numeric(final_input[col])
210
+
211
+ try:
212
+ pred = model.predict(final_input)[0]
213
+ st.success(f"### Estimated Price: β‚Ή {np.expm1(pred):,.2f}")
214
+ except Exception as e:
215
+ st.error(f"Error: {e}")
216
+ else:
217
+ st.error("Please select a valid route.")
218
+
219
+ elif page == "πŸ“Š Model Evaluation":
220
+ st.title("Model Evaluation")
221
+ if st.button("Evaluate"):
222
+ with st.spinner("Running..."):
223
+ y_pred = model.predict(x_test)
224
+ r2 = r2_score(y_test_series, y_pred)
225
+ st.metric("R2 Score", f"{r2:.4f}")
226
+
227
+ fig, ax = plt.subplots()
228
+ sns.scatterplot(x=np.expm1(y_test_series), y=np.expm1(y_pred), ax=ax)
229
+ ax.plot([0, 80000], [0, 80000], 'r--')
230
+ st.pyplot(fig)
x_test.parquet CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:14b2378d0cc0c08968a3ca37404afcb40de80a12b698a41e4cb128c0037aa2e6
3
- size 25918
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:008327ee502389d2aff543afc4f4a5749f7147856f9f15fd04e4d8ef744bf509
3
+ size 7260
y_test.parquet CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a8ed09322ef92c2d46b031a9f18c0202b5bef664652be589e4c5da5414c8c0f
3
- size 22385
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce163df1a6d494b9d502d2caed1e8b0cd424e31ff84371bfc21853f09f2343d3
3
+ size 3427