mimo1972 commited on
Commit
7973875
·
verified ·
1 Parent(s): c492db2

Update flightprice.py

Browse files
Files changed (1) hide show
  1. flightprice.py +219 -229
flightprice.py CHANGED
@@ -1,230 +1,220 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import joblib
5
- import matplotlib.pyplot as plt
6
- import seaborn as sns
7
- from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
8
-
9
- # Set Page Config
10
- st.set_page_config(page_title="Flight Price System", layout="wide", page_icon="✈️")
11
-
12
- # --- 1. Helper Functions ---
13
- @st.cache_data
14
- def load_data():
15
- try:
16
- x_train = pd.read_parquet('x_train.parquet')
17
- x_test = pd.read_parquet('x_test.parquet')
18
- y_test = pd.read_parquet('y_test.parquet')
19
-
20
- # CLEANING: Strip whitespace
21
- for col in ['Airline', 'Source', 'Destination', 'Route']:
22
- if col in x_train.columns:
23
- x_train[col] = x_train[col].astype(str).str.strip()
24
-
25
- return x_train, x_test, y_test
26
- except Exception as e:
27
- st.error(f"Error loading data files: {e}")
28
- return None, None, None
29
-
30
- @st.cache_resource
31
- def load_model():
32
- try:
33
- return joblib.load('AirFlights_HistBoost_model.pkl')
34
- except Exception as e:
35
- st.error(f"Error loading model: {e}")
36
- return None
37
-
38
- def get_day_name(day_num, month_name):
39
- try:
40
- date_str = f"{int(day_num)}-{month_name}-2019"
41
- return pd.to_datetime(date_str, format="%d-%B-%Y").day_name()
42
- except:
43
- return "Monday"
44
-
45
- def get_day_quarter(hour):
46
- if 5 <= hour < 12:
47
- return 'Morning'
48
- elif 12 <= hour < 17:
49
- return 'Afternoon'
50
- elif 17 <= hour < 21:
51
- return 'Evening'
52
- else:
53
- return 'Night'
54
-
55
- # --- 2. EXHAUSTIVE AIRPORT CODE MAPPING ---
56
- CITY_TO_CODE = {
57
- 'Banglore': 'BLR',
58
- 'Bangalore': 'BLR',
59
- 'Delhi': 'DEL',
60
- 'New Delhi': 'DEL',
61
- 'Kolkata': 'CCU',
62
- 'Calcutta': 'CCU',
63
- 'Hyderabad': 'HYD',
64
- 'Chennai': 'MAA',
65
- 'Madras': 'MAA',
66
- 'Mumbai': 'BOM',
67
- 'Bombay': 'BOM',
68
- 'Cochin': 'COK',
69
- 'Kochi': 'COK',
70
- 'Pune': 'PNQ',
71
- 'Goa': 'GOI',
72
- 'Jaipur': 'JAI',
73
- 'Lucknow': 'LKO',
74
- 'Patna': 'PAT',
75
- 'Varanasi': 'VNS',
76
- 'Bhubaneswar': 'BBI',
77
- 'Nagpur': 'NAG',
78
- 'Trivandrum': 'TRV'
79
- }
80
-
81
- def get_code(city):
82
- clean_city = city.strip()
83
- return CITY_TO_CODE.get(clean_city, clean_city[:3].upper())
84
-
85
- def is_route_valid(route_str, source_code, dest_code):
86
- if pd.isna(route_str) or route_str == 'nan':
87
- return False
88
- route_upper = route_str.upper()
89
- parts = route_upper.replace("→", " ").replace("->", " ").split()
90
- if not parts:
91
- return False
92
- first_stop = parts[0]
93
- return (first_stop == source_code) and (dest_code in parts)
94
-
95
- # Load Data
96
- x_train, x_test, y_test = load_data()
97
- model = load_model()
98
-
99
- if x_train is None or model is None:
100
- st.stop()
101
-
102
- if isinstance(y_test, pd.DataFrame):
103
- y_test_series = y_test.iloc[:, 0]
104
- else:
105
- y_test_series = y_test
106
-
107
- # Build Lookup
108
- route_lookup = {}
109
- if 'Route' in x_train.columns and 'Total_Stops' in x_train.columns:
110
- temp = x_train[['Route', 'Total_Stops']].drop_duplicates(subset=['Route'])
111
- route_lookup = temp.set_index('Route')['Total_Stops'].to_dict()
112
-
113
- # --- 3. App Layout ---
114
- st.sidebar.title("Navigation")
115
- page = st.sidebar.radio("Go to", ["💰 Price Prediction", "📊 Model Evaluation"])
116
-
117
- if page == "💰 Price Prediction":
118
- st.title("✈️ Flight Price Prediction")
119
- st.markdown("### Enter Flight Details")
120
-
121
- # REMOVED st.form HERE so inputs update instantly!
122
-
123
- c1, c2 = st.columns(2)
124
- with c1:
125
- st.subheader("Flight Info")
126
-
127
- # Source (Updates instantly now)
128
- source = st.selectbox("Source", sorted(x_train['Source'].unique()))
129
- src_code = get_code(source)
130
- st.success(f"🛫 Source Code: **{src_code}**")
131
-
132
- # Destination (Updates instantly now)
133
- destination = st.selectbox("Destination", sorted(x_train['Destination'].unique()))
134
- dest_code = get_code(destination)
135
- st.error(f"🛬 Destination Code: **{dest_code}**")
136
-
137
- airline = st.selectbox("Airline", sorted(x_train['Airline'].unique()))
138
-
139
- with c2:
140
- st.subheader("Date & Time")
141
- if 'Month' in x_train.columns:
142
- months = sorted(x_train['Month'].unique())
143
- else:
144
- months = ['March', 'April', 'May', 'June', 'September', 'December']
145
- month = st.selectbox("Month", months)
146
- day_number = st.number_input("Day Number", 1, 31, 1)
147
- dept_hour = st.number_input("Departure Hour", 0, 23, 10)
148
-
149
- st.markdown("---")
150
- st.subheader("Route Selection")
151
-
152
- selected_route = None
153
- stops_val = 0
154
-
155
- if 'Route' in x_train.columns:
156
- # 1. Get all routes
157
- all_routes = sorted(x_train['Route'].unique().astype(str))
158
-
159
- # 2. FILTER: Show only routes starting with the exact Source Code
160
- # Now this runs immediately when you change 'Source' above
161
- filtered_routes = []
162
- for r in all_routes:
163
- parts = r.upper().replace("→", " ").replace("->", " ").split()
164
- if parts and parts[0] == src_code:
165
- filtered_routes.append(r)
166
-
167
- if filtered_routes:
168
- selected_route_raw = st.selectbox("Select Route", filtered_routes)
169
-
170
- # 3. VALIDATE
171
- if is_route_valid(selected_route_raw, src_code, dest_code):
172
- selected_route = selected_route_raw
173
- stops_val = route_lookup.get(selected_route, 0)
174
- st.metric("Total Stops", stops_val)
175
- st.success("✅ Valid Route")
176
- else:
177
- st.warning(f"⚠️ **{selected_route_raw}** starts at **{src_code}** but does not reach **{dest_code}**. Please check your destination.")
178
- selected_route = None
179
- else:
180
- st.error(f"No routes found starting with code **{src_code}**.")
181
- else:
182
- st.error("Route column missing.")
183
-
184
- # We only use a button for the final prediction calculation
185
- if st.button("Predict Price", type="primary"):
186
- if selected_route:
187
- # Prepare Input
188
- day_name = get_day_name(day_number, month)
189
- quarter = get_day_quarter(dept_hour)
190
-
191
- input_df = pd.DataFrame({
192
- 'Airline': [airline], 'Source': [source], 'Destination': [destination],
193
- 'Month': [month], 'Route': [selected_route],
194
- 'Day_number': [day_number], 'Dept_hour': [dept_hour],
195
- 'Day': [day_name], 'Dept_Day_Quarter': [quarter],
196
- 'Total_Stops': [stops_val]
197
- })
198
-
199
- # Align Cols
200
- final_input = pd.DataFrame(columns=x_train.columns)
201
- for col in x_train.columns:
202
- final_input.loc[0, col] = input_df.iloc[0].get(col, 0)
203
-
204
- # Types
205
- for col in final_input.columns:
206
- if x_train[col].dtype == 'object':
207
- final_input[col] = final_input[col].astype(str)
208
- else:
209
- final_input[col] = pd.to_numeric(final_input[col])
210
-
211
- try:
212
- pred = model.predict(final_input)[0]
213
- st.success(f"### Estimated Price: ₹ {np.expm1(pred):,.2f}")
214
- except Exception as e:
215
- st.error(f"Error: {e}")
216
- else:
217
- st.error("Please select a valid route.")
218
-
219
- elif page == "📊 Model Evaluation":
220
- st.title("Model Evaluation")
221
- if st.button("Evaluate"):
222
- with st.spinner("Running..."):
223
- y_pred = model.predict(x_test)
224
- r2 = r2_score(y_test_series, y_pred)
225
- st.metric("R2 Score", f"{r2:.4f}")
226
-
227
- fig, ax = plt.subplots()
228
- sns.scatterplot(x=np.expm1(y_test_series), y=np.expm1(y_pred), ax=ax)
229
- ax.plot([0, 80000], [0, 80000], 'r--')
230
  st.pyplot(fig)
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
8
+
9
+ # Set Page Config
10
+ st.set_page_config(page_title="Flight Price System", layout="wide", page_icon="✈️")
11
+
12
+ # --- 1. Helper Functions ---
13
+ @st.cache_data
14
+ def load_data():
15
+ try:
16
+ x_train = pd.read_parquet('x_train.parquet')
17
+ x_test = pd.read_parquet('x_test.parquet')
18
+ y_test = pd.read_parquet('y_test.parquet')
19
+
20
+ # CLEANING: Strip whitespace
21
+ for col in ['Airline', 'Source', 'Destination', 'Route']:
22
+ if col in x_train.columns:
23
+ x_train[col] = x_train[col].astype(str).str.strip()
24
+
25
+ return x_train, x_test, y_test
26
+ except Exception as e:
27
+ st.error(f"Error loading data files: {e}")
28
+ return None, None, None
29
+
30
+ @st.cache_resource
31
+ def load_model():
32
+ try:
33
+ return joblib.load('AirFlights_HistBoost_model.pkl')
34
+ except Exception as e:
35
+ st.error(f"Error loading model: {e}")
36
+ return None
37
+
38
+ def get_day_name(day_num, month_name):
39
+ try:
40
+ date_str = f"{int(day_num)}-{month_name}-2019"
41
+ return pd.to_datetime(date_str, format="%d-%B-%Y").day_name()
42
+ except:
43
+ return "Monday"
44
+
45
+ def get_day_quarter(hour):
46
+ if 5 <= hour < 12:
47
+ return 'Morning'
48
+ elif 12 <= hour < 17:
49
+ return 'Afternoon'
50
+ elif 17 <= hour < 21:
51
+ return 'Evening'
52
+ else:
53
+ return 'Night'
54
+
55
+ # --- 2. MAPPING FOR VALIDATION ONLY ---
56
+ # We use this to help check if the Destination is in the Route string
57
+ CITY_TO_CODE = {
58
+ 'Banglore': 'BLR', 'Bangalore': 'BLR',
59
+ 'Delhi': 'DEL', 'New Delhi': 'DEL',
60
+ 'Kolkata': 'CCU', 'Calcutta': 'CCU',
61
+ 'Hyderabad': 'HYD', 'Chennai': 'MAA', 'Madras': 'MAA',
62
+ 'Mumbai': 'BOM', 'Bombay': 'BOM',
63
+ 'Cochin': 'COK', 'Kochi': 'COK',
64
+ 'Pune': 'PNQ', 'Goa': 'GOI', 'Jaipur': 'JAI',
65
+ 'Lucknow': 'LKO', 'Patna': 'PAT', 'Varanasi': 'VNS'
66
+ }
67
+
68
+ def get_code(city):
69
+ return CITY_TO_CODE.get(city.strip(), city[:3].upper())
70
+
71
+ def is_route_valid(route_str, destination):
72
+ """
73
+ Checks if the Destination (Name or Code) appears in the Route string.
74
+ """
75
+ if pd.isna(route_str) or route_str == 'nan':
76
+ return False
77
+
78
+ route_upper = route_str.upper()
79
+ dest_name_upper = destination.upper()
80
+ dest_code = get_code(destination)
81
+
82
+ # Check if code (e.g., 'COK') or name (e.g., 'COCHIN') is in string
83
+ return (dest_code in route_upper) or (dest_name_upper in route_upper)
84
+
85
+ # Load Data
86
+ x_train, x_test, y_test = load_data()
87
+ model = load_model()
88
+
89
+ if x_train is None or model is None:
90
+ st.stop()
91
+
92
+ if isinstance(y_test, pd.DataFrame):
93
+ y_test_series = y_test.iloc[:, 0]
94
+ else:
95
+ y_test_series = y_test
96
+
97
+ # Build Lookup for Stops
98
+ route_lookup = {}
99
+ if 'Route' in x_train.columns and 'Total_Stops' in x_train.columns:
100
+ temp = x_train[['Route', 'Total_Stops']].drop_duplicates(subset=['Route'])
101
+ route_lookup = temp.set_index('Route')['Total_Stops'].to_dict()
102
+
103
+ # --- 3. App Layout ---
104
+ st.sidebar.title("Navigation")
105
+ page = st.sidebar.radio("Go to", ["💰 Price Prediction", "📊 Model Evaluation"])
106
+
107
+ if page == "💰 Price Prediction":
108
+ st.title("✈️ Flight Price Prediction")
109
+ st.markdown("### Enter Flight Details")
110
+
111
+ c1, c2 = st.columns(2)
112
+ with c1:
113
+ st.subheader("Flight Info")
114
+
115
+ # Source (Updates instantly)
116
+ source = st.selectbox("Source", sorted(x_train['Source'].unique()))
117
+
118
+ # Destination (Updates instantly)
119
+ destination = st.selectbox("Destination", sorted(x_train['Destination'].unique()))
120
+
121
+ airline = st.selectbox("Airline", sorted(x_train['Airline'].unique()))
122
+
123
+ with c2:
124
+ st.subheader("Date & Time")
125
+ if 'Month' in x_train.columns:
126
+ months = sorted(x_train['Month'].unique())
127
+ else:
128
+ months = ['March', 'April', 'May', 'June', 'September', 'December']
129
+ month = st.selectbox("Month", months)
130
+ day_number = st.number_input("Day Number", 1, 31, 1)
131
+ dept_hour = st.number_input("Departure Hour", 0, 23, 10)
132
+
133
+ st.markdown("---")
134
+ st.subheader("Route Selection")
135
+
136
+ selected_route = None
137
+ stops_val = 0
138
+
139
+ if 'Route' in x_train.columns:
140
+ # --- 1. SMART FILTERING ---
141
+ # Instead of guessing codes, we ask the DataFrame:
142
+ # "Give me all routes that exist for this specific Source"
143
+ valid_source_routes = x_train[x_train['Source'] == source]['Route'].unique()
144
+ valid_source_routes = sorted(valid_source_routes.astype(str))
145
+
146
+ if len(valid_source_routes) > 0:
147
+ selected_route_raw = st.selectbox("Select Route", valid_source_routes)
148
+
149
+ # --- 2. VALIDATION ---
150
+ # Check if the route actually contains the destination
151
+ if is_route_valid(selected_route_raw, destination):
152
+ selected_route = selected_route_raw
153
+ stops_val = route_lookup.get(selected_route, 0)
154
+ st.metric("Total Stops", stops_val)
155
+ st.success("✅ Valid Route")
156
+ else:
157
+ # Warning if user picks a route that doesn't go to their destination
158
+ st.warning(f"⚠️ **{selected_route_raw}** does not seem to contain **{destination}** ({get_code(destination)}). Please confirm if this is correct.")
159
+ # We allow them to proceed but with a warning, or you can block it:
160
+ # selected_route = None
161
+ selected_route = selected_route_raw
162
+ stops_val = route_lookup.get(selected_route, 0)
163
+
164
+ else:
165
+ st.error(f"No routes found in database for Source: {source}")
166
+ else:
167
+ st.error("Route column missing.")
168
+
169
+ st.markdown("<br>", unsafe_allow_html=True)
170
+
171
+ if st.button("Predict Price", type="primary"):
172
+ if selected_route:
173
+ # Prepare Input
174
+ day_name = get_day_name(day_number, month)
175
+ quarter = get_day_quarter(dept_hour)
176
+
177
+ input_df = pd.DataFrame({
178
+ 'Airline': [airline], 'Source': [source], 'Destination': [destination],
179
+ 'Month': [month], 'Route': [selected_route],
180
+ 'Day_number': [day_number], 'Dept_hour': [dept_hour],
181
+ 'Day': [day_name], 'Dept_Day_Quarter': [quarter],
182
+ 'Total_Stops': [stops_val]
183
+ })
184
+
185
+ # Align Cols
186
+ final_input = pd.DataFrame(columns=x_train.columns)
187
+ for col in x_train.columns:
188
+ final_input.loc[0, col] = input_df.iloc[0].get(col, 0)
189
+
190
+ # Types
191
+ for col in final_input.columns:
192
+ if x_train[col].dtype == 'object':
193
+ final_input[col] = final_input[col].astype(str)
194
+ else:
195
+ final_input[col] = pd.to_numeric(final_input[col])
196
+
197
+ try:
198
+ pred = model.predict(final_input)[0]
199
+ st.success(f"### Estimated Price: ₹ {np.expm1(pred):,.2f}")
200
+ except Exception as e:
201
+ st.error(f"Error: {e}")
202
+ else:
203
+ st.error("Please select a valid route.")
204
+
205
+ elif page == "📊 Model Evaluation":
206
+ st.title("Model Evaluation")
207
+ if st.button("Evaluate"):
208
+ with st.spinner("Running..."):
209
+ y_pred = model.predict(x_test)
210
+ r2 = r2_score(y_test_series, y_pred)
211
+ st.metric("R2 Score", f"{r2:.4f}")
212
+
213
+ fig, ax = plt.subplots(figsize=(10, 6))
214
+ sns.scatterplot(x=np.expm1(y_test_series), y=np.expm1(y_pred), ax=ax, alpha=0.6)
215
+ min_val = min(np.expm1(y_test_series).min(), np.expm1(y_pred).min())
216
+ max_val = max(np.expm1(y_test_series).max(), np.expm1(y_pred).max())
217
+ ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2)
218
+ ax.set_xlabel("Actual Price")
219
+ ax.set_ylabel("Predicted Price")
 
 
 
 
 
 
 
 
 
 
220
  st.pyplot(fig)