Arsive2 commited on
Commit
57787e1
·
1 Parent(s): c16c736

updated hf-code

Browse files
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.csv filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
DockerFile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ COPY app.py .
7
+ COPY prediction.py .
8
+
9
+ RUN mkdir -p data
10
+
11
+ COPY data/ ./data/
12
+
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ EXPOSE 7860
16
+
17
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README copy.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Flight Savvy Hf
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: gray
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ short_description: Best time to buy flight
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+ from typing import Optional, Union
4
+
5
+ import uvicorn
6
+ from fastapi import FastAPI, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from huggingface_hub import hf_hub_download
9
+ from pydantic import BaseModel
10
+
11
+ # Import the prediction function
12
+ from prediction import predict_best_time_to_buy_ticket
13
+
14
+ # Create FastAPI instance
15
+ app = FastAPI(title="FlightSavvy API",
16
+ description="API for predicting the best time to buy flight tickets",
17
+ version="1.0.0")
18
+
19
+ # Add CORS middleware
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"], # Allows all origins
23
+ allow_credentials=True,
24
+ allow_methods=["*"], # Allows all methods
25
+ allow_headers=["*"], # Allows all headers
26
+ )
27
+
28
+ # Define request model with Pydantic
29
+ class PredictionRequest(BaseModel):
30
+ origin: str
31
+ destination: str
32
+ granularity: Optional[str] = "quarter"
33
+ futureYear: Optional[int] = None
34
+ weeksAhead: Optional[int] = None
35
+ start_month: Optional[Union[int, str]] = None
36
+ end_month: Optional[Union[int, str]] = None
37
+ carrier: Optional[str] = None
38
+
39
+ # Download models on startup if they don't exist
40
+ @app.on_event("startup")
41
+ async def download_models():
42
+ models = ["flight_fare_rf_model.joblib", "flight_fare_ts_model.joblib"]
43
+ # Replace with your actual username
44
+ repo_id = "your-username/flightsavvy-models"
45
+
46
+ for model in models:
47
+ if not os.path.exists(model):
48
+ try:
49
+ print(f"Downloading {model} from Hugging Face...")
50
+ hf_hub_download(repo_id=repo_id, filename=model, local_dir=".")
51
+ print(f"Downloaded {model} successfully")
52
+ except Exception as e:
53
+ print(f"Error downloading {model}: {e}")
54
+ # Continue even if download fails - prediction.py has fallbacks
55
+
56
+ @app.post("/api/predict")
57
+ async def predict(request: PredictionRequest):
58
+ try:
59
+ # Convert month names to numbers if necessary
60
+ months = ['January', 'February', 'March', 'April', 'May', 'June',
61
+ 'July', 'August', 'September', 'October', 'November', 'December']
62
+
63
+ start_month = request.start_month
64
+ if isinstance(start_month, str) and not start_month.isdigit():
65
+ try:
66
+ start_month = months.index(start_month) + 1
67
+ except ValueError:
68
+ pass
69
+
70
+ end_month = request.end_month
71
+ if isinstance(end_month, str) and not end_month.isdigit():
72
+ try:
73
+ end_month = months.index(end_month) + 1
74
+ except ValueError:
75
+ pass
76
+
77
+ # Call prediction function with parameters from request
78
+ result = predict_best_time_to_buy_ticket(
79
+ origin=request.origin,
80
+ destination=request.destination,
81
+ granularity=request.granularity,
82
+ future_year=request.futureYear,
83
+ weeks_ahead=request.weeksAhead,
84
+ start_month=start_month,
85
+ end_month=end_month,
86
+ carrier=request.carrier
87
+ )
88
+ return result
89
+ except Exception as e:
90
+ # Log the error
91
+ print(f"API Error: {str(e)}")
92
+ print(traceback.format_exc())
93
+ # Return error response
94
+ raise HTTPException(status_code=500, detail=str(e))
95
+
96
+ # If running directly, start the server
97
+ if __name__ == "__main__":
98
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)
data/US Airline Flight Routes and Fares 1993-2024.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51f54079e9e089f9eb1ed4795c983b3cdf12d04e70036cdcc7aa18a8d3828937
3
+ size 63039765
prediction.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+
3
+ matplotlib.use('Agg') # Agg backend for non-GUI environments
4
+
5
+ import calendar
6
+ import datetime
7
+ import json
8
+ import logging
9
+ import random
10
+
11
+ import joblib
12
+ import numpy as np
13
+ import pandas as pd
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv()
17
+
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
+
20
+ months = ['January', 'February', 'March', 'April', 'May', 'June',
21
+ 'July', 'August', 'September', 'October', 'November', 'December']
22
+
23
+ CARRIER_CATEGORIES = {
24
+ # Premium/Legacy Carriers (15-30% more expensive)
25
+ 'PREMIUM': {
26
+ 'AA': {'name': 'American Airlines', 'factor': 1.20},
27
+ 'DL': {'name': 'Delta Air Lines', 'factor': 1.25},
28
+ 'UA': {'name': 'United Airlines', 'factor': 1.18},
29
+ 'AS': {'name': 'Alaska Airlines', 'factor': 1.15},
30
+ 'US': {'name': 'US Airways (merged with AA)', 'factor': 1.17},
31
+ 'CO': {'name': 'Continental (merged with UA)', 'factor': 1.18},
32
+ 'NW': {'name': 'Northwest (merged with DL)', 'factor': 1.20},
33
+ 'TW': {'name': 'Trans World Airlines', 'factor': 1.15},
34
+ 'PA': {'name': 'Pan Am', 'factor': 1.20},
35
+ },
36
+
37
+ # Mid-tier Carriers (base price to 10% more)
38
+ 'MID_TIER': {
39
+ 'B6': {'name': 'JetBlue Airways', 'factor': 1.08},
40
+ 'WN': {'name': 'Southwest Airlines', 'factor': 1.00},
41
+ 'SY': {'name': 'Sun Country Airlines', 'factor': 1.03},
42
+ 'FL': {'name': 'AirTran Airways', 'factor': 1.02},
43
+ 'VX': {'name': 'Virgin America', 'factor': 1.10},
44
+ 'HP': {'name': 'America West Airlines', 'factor': 1.05},
45
+ 'AQ': {'name': 'Aloha Airlines', 'factor': 1.05},
46
+ 'QX': {'name': 'Horizon Air', 'factor': 1.05},
47
+ },
48
+
49
+ # Budget Carriers (15-30% less expensive)
50
+ 'BUDGET': {
51
+ 'NK': {'name': 'Spirit Airlines', 'factor': 0.75},
52
+ 'F9': {'name': 'Frontier Airlines', 'factor': 0.70},
53
+ 'G4': {'name': 'Allegiant Air', 'factor': 0.80},
54
+ 'HQ': {'name': 'Harmony Airways', 'factor': 0.85},
55
+ 'JI': {'name': 'Midway Airlines', 'factor': 0.85},
56
+ 'TZ': {'name': 'ATA Airlines', 'factor': 0.80},
57
+ 'WV': {'name': 'Air South', 'factor': 0.75},
58
+ 'BF': {'name': 'Markair', 'factor': 0.80},
59
+ 'SX': {'name': 'Skybus Airlines', 'factor': 0.65},
60
+ },
61
+
62
+ # Regional Carriers (5-15% less expensive)
63
+ 'REGIONAL': {
64
+ 'OO': {'name': 'SkyWest Airlines', 'factor': 0.90},
65
+ 'YX': {'name': 'Republic Airways', 'factor': 0.90},
66
+ 'YV': {'name': 'Mesa Airlines', 'factor': 0.92},
67
+ 'DH': {'name': 'Independence Air', 'factor': 0.88},
68
+ 'OH': {'name': 'PSA Airlines', 'factor': 0.90},
69
+ 'ZW': {'name': 'Air Wisconsin', 'factor': 0.90},
70
+ 'KS': {'name': 'Peninsula Airways', 'factor': 0.88},
71
+ '9K': {'name': 'Cape Air', 'factor': 0.85},
72
+ 'XJ': {'name': 'Mesaba Airlines', 'factor': 0.88},
73
+ 'RP': {'name': 'Chautauqua Airlines', 'factor': 0.90},
74
+ 'P9': {'name': 'Colgan Air', 'factor': 0.90},
75
+ 'ZV': {'name': 'Air Midwest', 'factor': 0.88},
76
+ },
77
+
78
+ # International Carriers
79
+ 'INTERNATIONAL': {
80
+ '3M': {'name': 'LATAM Airlines (formerly LAN)', 'factor': 1.22},
81
+ 'MX': {'name': 'Mexicana Airlines', 'factor': 1.10},
82
+ 'XP': {'name': 'XpressAir', 'factor': 1.05},
83
+ '5J': {'name': 'Cebu Pacific', 'factor': 0.90},
84
+ 'UK': {'name': 'Vistara', 'factor': 1.15},
85
+ 'KW': {'name': 'Korea Express Air', 'factor': 1.20},
86
+ 'KP': {'name': 'ASKY Airlines', 'factor': 1.10},
87
+ },
88
+
89
+ # Miscellaneous/Charter/Smaller Carriers
90
+ 'OTHER': {
91
+ 'RU': {'name': 'AirBridgeCargo', 'factor': 1.00},
92
+ 'J7': {'name': 'ValueJet', 'factor': 0.85},
93
+ 'U5': {'name': 'USA 3000 Airlines', 'factor': 0.90},
94
+ 'N7': {'name': 'National Airlines', 'factor': 1.00},
95
+ 'NJ': {'name': 'Visionair', 'factor': 0.95},
96
+ 'QQ': {'name': 'Reno Air', 'factor': 0.95},
97
+ 'W7': {'name': 'Western Pacific Airlines', 'factor': 0.93},
98
+ 'FF': {'name': 'Tower Air', 'factor': 0.90},
99
+ 'TB': {'name': 'USAir Shuttle', 'factor': 1.10},
100
+ 'LC': {'name': 'Logging Air', 'factor': 1.05},
101
+ 'YY': {'name': 'American Connection', 'factor': 0.95},
102
+ 'KN': {'name': 'China United Airlines', 'factor': 1.10},
103
+ 'E9': {'name': 'Evelop Airlines', 'factor': 1.05},
104
+ 'PN': {'name': 'Pan American Airways', 'factor': 1.10},
105
+ '9N': {'name': 'Northern Thunderbird Air', 'factor': 1.00},
106
+ 'U2': {'name': 'easyJet', 'factor': 0.85},
107
+ 'OE': {'name': 'Asia Overnight Express', 'factor': 1.05},
108
+ 'W9': {'name': 'Eastwind Airlines', 'factor': 0.90},
109
+ 'RL': {'name': 'Royal Airlines', 'factor': 1.10},
110
+ 'T3': {'name': 'Eastern Airways', 'factor': 1.00},
111
+ 'OP': {'name': 'Chalk\'s Ocean Airways', 'factor': 1.10},
112
+ 'ZA': {'name': 'Access Air', 'factor': 0.95},
113
+ }
114
+ }
115
+
116
+ # Base prices for popular routes
117
+ BASE_ROUTE_PRICES = {
118
+ 'ABQ-AUS': 95.00, # Albuquerque to Austin
119
+ 'LAX-JFK': 250.00, # Los Angeles to New York
120
+ 'ORD-DFW': 140.00, # Chicago to Dallas
121
+ 'ATL-LAS': 175.00, # Atlanta to Las Vegas
122
+ 'SFO-SEA': 120.00, # San Francisco to Seattle
123
+ 'DFW-LAX': 150.00, # Dallas to Los Angeles
124
+ 'DEN-PHX': 110.00, # Denver to Phoenix
125
+ 'MIA-JFK': 130.00, # Miami to New York
126
+ 'BOS-ORD': 120.00, # Boston to Chicago
127
+ 'SEA-LAS': 95.00, # Seattle to Las Vegas
128
+ }
129
+
130
+ def load_model_with_fallback(model_name):
131
+ """
132
+ Load a model with fallback mechanisms if the file isn't found locally.
133
+ """
134
+ try:
135
+ # First try to load locally
136
+ model = joblib.load(model_name)
137
+ print(f"Successfully loaded {model_name} from local path")
138
+ return model
139
+ except FileNotFoundError:
140
+ try:
141
+ # Try to download from Hugging Face
142
+ from huggingface_hub import hf_hub_download
143
+
144
+ # Replace with your actual username
145
+ repo_id = "Arsive/flight-fare-prediction"
146
+ hf_path = hf_hub_download(repo_id=repo_id, filename=model_name)
147
+ model = joblib.load(hf_path)
148
+ print(f"Successfully loaded {model_name} from Hugging Face Hub")
149
+ return model
150
+ except Exception as e:
151
+ print(f"Error loading model {model_name}: {e}")
152
+ print("Using fallback dummy model")
153
+ # Return a dummy model for demonstration purposes
154
+ from sklearn.ensemble import RandomForestRegressor
155
+ dummy_model = RandomForestRegressor()
156
+ dummy_model.fit([[0]], [0]) # Fit with dummy data
157
+ return dummy_model
158
+
159
+ # Function to get carrier information and pricing factor
160
+ def get_carrier_info(carrier_code):
161
+ """
162
+ Return the carrier info including name and pricing factor.
163
+ If carrier not found, returns default values.
164
+ """
165
+ if not carrier_code or carrier_code == 'nan':
166
+ return {'name': 'Unknown', 'factor': 1.0}
167
+
168
+ for category, carriers in CARRIER_CATEGORIES.items():
169
+ if carrier_code in carriers:
170
+ return {
171
+ 'name': carriers[carrier_code]['name'],
172
+ 'factor': carriers[carrier_code]['factor'],
173
+ 'category': category
174
+ }
175
+
176
+ # If carrier not found in any category
177
+ return {'name': f'Carrier {carrier_code}', 'factor': 1.0, 'category': 'UNKNOWN'}
178
+
179
+ import numpy as np
180
+
181
+
182
+ # convert_numpy_types function to explicitly handle bool_ types
183
+ def convert_numpy_types(obj):
184
+ """
185
+ Convert numpy types to native Python types for JSON serialization
186
+ """
187
+ if isinstance(obj, np.integer):
188
+ return int(obj)
189
+ elif isinstance(obj, np.floating):
190
+ return float(obj)
191
+ elif isinstance(obj, np.ndarray):
192
+ return obj.tolist()
193
+ elif isinstance(obj, np.bool_): # Add explicit handling for NumPy boolean type
194
+ return bool(obj)
195
+ elif isinstance(obj, datetime.date):
196
+ return obj.isoformat()
197
+ elif isinstance(obj, (dict, pd.Series)):
198
+ return {k: convert_numpy_types(v) for k, v in obj.items()}
199
+ elif isinstance(obj, list):
200
+ return [convert_numpy_types(item) for item in obj]
201
+ else:
202
+ return obj
203
+
204
+ # Function to adjust fare based on carrier and route
205
+ def adjust_fare_by_carrier(fare, carrier_code, route=None):
206
+ """
207
+ Adjust fare based on carrier and optionally the specific route.
208
+ """
209
+ # Get carrier info with pricing factor
210
+ carrier_info = get_carrier_info(carrier_code)
211
+ carrier_factor = carrier_info['factor']
212
+
213
+ route_factor = 1.0
214
+ if route and route in BASE_ROUTE_PRICES:
215
+ # Some carriers may have special pricing on specific routes
216
+ if carrier_code == 'WN' and route in ['DAL-HOU', 'LAS-PHX', 'ABQ-AUS']:
217
+ route_factor = 0.90 # Southwest cheaper on their hub routes
218
+ elif carrier_code == 'DL' and route in ['ATL-JFK', 'DTW-MSP']:
219
+ route_factor = 0.95 # Delta cheaper on their hub routes
220
+ elif carrier_code == 'F9' and route in ['DEN-LAS', 'DEN-PHX']:
221
+ route_factor = 0.85 # Frontier cheaper from Denver
222
+
223
+ # Apply small random variation (±5%)
224
+ variation_factor = random.uniform(0.95, 1.05)
225
+
226
+ # Calculate final adjusted fare with carrier, route, and variation factors
227
+ adjusted_fare = fare * carrier_factor * route_factor * variation_factor
228
+
229
+ return round(adjusted_fare, 2)
230
+
231
+ def predict_best_time_to_buy_ticket(origin, destination, granularity="quarter",
232
+ future_year=None, filepath=None,
233
+ weeks_ahead=None, start_month=None, end_month=None,
234
+ carrier=None):
235
+ """
236
+ Use Hugging Face-hosted models to predict the best time to buy a ticket
237
+
238
+ Parameters:
239
+ -----------
240
+ origin : str
241
+ Origin airport code (e.g., 'ABQ')
242
+ destination : str
243
+ Destination airport code (e.g., 'PHX')
244
+ granularity : str, optional
245
+ Prediction granularity: "date", "week", "month", or "quarter"
246
+ future_year : int, optional
247
+ Year to predict for (defaults to current year)
248
+ filepath : str, optional
249
+ Path to sample data for feature extraction
250
+ weeks_ahead : int, optional
251
+ If predicting for specific dates, how many weeks ahead to predict
252
+ start_month : int, optional
253
+ Start month of travel period (1-12)
254
+ end_month : int, optional
255
+ End month of travel period (1-12)
256
+ carrier : str, optional
257
+ Airline/carrier code to filter results (e.g., 'AA' for American Airlines)
258
+
259
+ Returns:
260
+ --------
261
+ dict
262
+ Contains best_time, predictions, and chart data for visualization
263
+ """
264
+ try:
265
+ origin = origin.upper()
266
+ destination = destination.upper()
267
+ route_name = f"{origin}-{destination}"
268
+
269
+ if future_year is None:
270
+ future_year = datetime.datetime.now().year
271
+ else:
272
+ future_year = int(future_year)
273
+
274
+ print(f"Predicting best time to buy for route: {route_name} with granularity: {granularity}")
275
+ if start_month and end_month:
276
+ print(f"Travel period: Months {start_month} to {end_month}")
277
+
278
+ carrier_info = None
279
+ if carrier:
280
+ carrier_info = get_carrier_info(carrier)
281
+ print(f"Filtering for carrier: {carrier} ({carrier_info['name']})")
282
+ print(f"Carrier pricing factor: {carrier_info['factor']}")
283
+
284
+ try:
285
+ rf_model = load_model_with_fallback('flight_fare_rf_model.joblib')
286
+ except Exception as e:
287
+ print(f"Error loading Random Forest model: {e}")
288
+
289
+ try:
290
+ ts_model = load_model_with_fallback('flight_fare_ts_model.joblib')
291
+ except Exception as e:
292
+ print(f"Time series model not available: {str(e)}")
293
+ ts_model = None
294
+
295
+ if filepath is None:
296
+ filepath = 'data/US Airline Flight Routes and Fares 1993-2024.csv'
297
+
298
+ print(f"Loading data from {filepath}")
299
+ try:
300
+ df = pd.read_csv(filepath)
301
+ except Exception as e:
302
+ print(f"Error loading data file: {e}")
303
+ df = pd.DataFrame({
304
+ 'airport_1': ['DFW', 'LAX', 'ATL'],
305
+ 'airport_2': ['LAX', 'JFK', 'MIA'],
306
+ 'route': ['DFW-LAX', 'LAX-JFK', 'ATL-MIA'],
307
+ 'Year': [2024, 2024, 2024],
308
+ 'quarter': [1, 2, 3],
309
+ 'fare': [250, 350, 300],
310
+ 'nsmiles': [1200, 2500, 600],
311
+ 'passengers': [300, 400, 200],
312
+ 'carrier_lg': ['AA', 'DL', 'WN'],
313
+ 'large_ms': [0.8, 0.7, 0.6],
314
+ 'fare_lg': [280, 380, 320],
315
+ 'carrier_low': ['WN', 'UA', 'NK'],
316
+ 'lf_ms': [0.2, 0.3, 0.4],
317
+ 'fare_low': [200, 300, 250]
318
+ })
319
+
320
+ if 'route' not in df.columns:
321
+ print("Creating 'route' column from airport codes")
322
+ df['route'] = df['airport_1'] + '-' + df['airport_2']
323
+
324
+ route_data = df[df['route'] == route_name].copy()
325
+
326
+ # If no exact route match, find similar routes or use average values
327
+ if route_data.empty:
328
+ print(f"No data for exact route {route_name}, using similar routes or average values")
329
+ origin_routes = df[df['airport_1'] == origin]
330
+ if not origin_routes.empty:
331
+ route_data = origin_routes.iloc[0:1].copy()
332
+ print(f"Using data from route with same origin: {route_data['route'].values[0]}")
333
+ else:
334
+ route_data = df.iloc[0:1].copy()
335
+ print("Using average route data")
336
+
337
+ route_data['airport_1'] = origin
338
+ route_data['airport_2'] = destination
339
+ route_data['route'] = route_name
340
+
341
+ if route_name in BASE_ROUTE_PRICES:
342
+ print(f"Using base price for route {route_name}: ${BASE_ROUTE_PRICES[route_name]}")
343
+ route_data['fare'] = BASE_ROUTE_PRICES[route_name]
344
+ else:
345
+ if 'nsmiles' in route_data.columns:
346
+ estimated_fare = route_data['nsmiles'].values[0] * 0.15 # $0.15 per mile as base
347
+ print(f"Estimating fare based on distance: ${estimated_fare:.2f}")
348
+ route_data['fare'] = estimated_fare
349
+
350
+ if carrier and carrier_info:
351
+ base_fare = route_data['fare'].values[0]
352
+
353
+ adjusted_fare = adjust_fare_by_carrier(base_fare, carrier, route_name)
354
+
355
+ print(f"Adjusting fare for {carrier} ({carrier_info['name']}): ${base_fare:.2f} → ${adjusted_fare:.2f}")
356
+
357
+ route_data['fare'] = adjusted_fare
358
+
359
+ route_data['carrier'] = carrier
360
+ route_data['carrier_name'] = carrier_info['name']
361
+ route_data['carrier_category'] = carrier_info.get('category', 'UNKNOWN')
362
+
363
+ if 'carrier_lg' in route_data.columns and route_data['carrier_lg'].values[0] == carrier:
364
+ if 'fare_lg' in route_data.columns:
365
+ print(f"Using {carrier} as the major carrier with fare: ${route_data['fare_lg'].values[0]:.2f}")
366
+ route_data['fare'] = route_data['fare_lg']
367
+ elif 'carrier_low' in route_data.columns and route_data['carrier_low'].values[0] == carrier:
368
+ if 'fare_low' in route_data.columns:
369
+ print(f"Using {carrier} as the low-fare carrier with fare: ${route_data['fare_low'].values[0]:.2f}")
370
+ route_data['fare'] = route_data['fare_low']
371
+
372
+ print("Engineering required features")
373
+
374
+ if 'nsmiles' in route_data.columns and 'fare' in route_data.columns:
375
+ route_data['price_per_mile'] = route_data['fare'] / route_data['nsmiles']
376
+ else:
377
+ route_data['price_per_mile'] = 0.25 # Default average price per mile
378
+
379
+ if 'large_ms' in route_data.columns and 'lf_ms' in route_data.columns:
380
+ route_data['market_concentration'] = np.maximum(
381
+ route_data['large_ms'], route_data['lf_ms'])
382
+ else:
383
+ route_data['market_concentration'] = 0.8 # Default high concentration
384
+
385
+ if 'fare_lg' in route_data.columns and 'fare_low' in route_data.columns:
386
+ route_data['price_difference'] = route_data['fare_lg'] - route_data['fare_low']
387
+ else:
388
+ route_data['price_difference'] = 20.0 # Default difference
389
+
390
+ if 'carrier_lg' in route_data.columns:
391
+ # If we had access to all data, we'd group by route
392
+ route_data['route_competition'] = 2 # Default: assume 2 carriers
393
+ else:
394
+ route_data['route_competition'] = 2 # Default competition value
395
+
396
+ if 'season' not in route_data.columns:
397
+ print("Adding season column")
398
+ seasons = {1: 'Winter', 2: 'Spring', 3: 'Summer', 4: 'Fall'}
399
+ route_data['quarter'] = route_data['quarter'].astype(int)
400
+ route_data['season'] = route_data['quarter'].map(seasons)
401
+
402
+ required_columns = ['Year', 'quarter', 'nsmiles', 'passengers']
403
+ for col in required_columns:
404
+ if col not in route_data.columns:
405
+ # Add defaults if missing
406
+ if col == 'nsmiles':
407
+ route_data['nsmiles'] = 800 # Default distance
408
+ elif col == 'passengers':
409
+ route_data['passengers'] = 250 # Default passenger count
410
+
411
+ prediction_dates = []
412
+
413
+ if granularity == "date":
414
+ if weeks_ahead is None:
415
+ weeks_ahead = 12 # Default to 12 weeks (about 3 months) ahead
416
+
417
+ start_date = datetime.datetime.now().date()
418
+ for i in range(weeks_ahead * 7):
419
+ prediction_dates.append(start_date + datetime.timedelta(days=i))
420
+
421
+ if start_month is not None and end_month is not None:
422
+ filtered_dates = []
423
+
424
+ def is_in_travel_period(date):
425
+ month = date.month
426
+ if start_month <= end_month:
427
+ return start_month <= month <= end_month
428
+ else: # Wrap around case (e.g., November to February)
429
+ return month >= start_month or month <= end_month
430
+
431
+ for date in prediction_dates:
432
+ if is_in_travel_period(date):
433
+ filtered_dates.append(date)
434
+
435
+ if filtered_dates:
436
+ prediction_dates = filtered_dates
437
+
438
+ elif granularity == "week":
439
+ start_date = datetime.datetime(future_year, 1, 1)
440
+ while start_date.weekday() != 0: # 0 = Monday
441
+ start_date += datetime.timedelta(days=1)
442
+
443
+ current_date = start_date
444
+ while current_date.year == future_year:
445
+ prediction_dates.append(current_date.date())
446
+ current_date += datetime.timedelta(days=7)
447
+
448
+ if start_month is not None and end_month is not None:
449
+ filtered_dates = []
450
+
451
+ def is_in_travel_period(date):
452
+ month = date.month
453
+ if start_month <= end_month:
454
+ return start_month <= month <= end_month
455
+ else: # Wrap around case (e.g., November to February)
456
+ return month >= start_month or month <= end_month
457
+
458
+ for date in prediction_dates:
459
+ if is_in_travel_period(date):
460
+ filtered_dates.append(date)
461
+
462
+ if filtered_dates:
463
+ prediction_dates = filtered_dates
464
+
465
+ elif granularity == "month":
466
+ for month in range(1, 13):
467
+ prediction_dates.append(datetime.datetime(future_year, month, 1).date())
468
+
469
+ elif granularity == "quarter":
470
+ quarter_months = [2, 5, 8, 11] # February, May, August, November
471
+ for month in quarter_months:
472
+ prediction_dates.append(datetime.datetime(future_year, month, 15).date())
473
+
474
+ predictions = []
475
+
476
+ for pred_date in prediction_dates:
477
+ print(f"Generating prediction for {pred_date}")
478
+
479
+ sample_data = route_data.iloc[0].copy()
480
+
481
+ sample_data['Year'] = pred_date.year
482
+ sample_data['month'] = pred_date.month
483
+ sample_data['day_of_year'] = pred_date.timetuple().tm_yday
484
+
485
+ quarter = (pred_date.month - 1) // 3 + 1
486
+ sample_data['quarter'] = quarter
487
+
488
+ week_number = pred_date.isocalendar()[1]
489
+ sample_data['week'] = week_number
490
+
491
+ major_holidays = [
492
+ (1, 1), # New Year's
493
+ (12, 25), # Christmas
494
+ (11, [20, 21, 22, 23, 24, 25, 26, 27, 28]), # Thanksgiving range
495
+ (7, 4), # 4th of July
496
+ (5, [25, 26, 27, 28, 29, 30, 31]), # Memorial Day range
497
+ (9, [1, 2, 3, 4, 5, 6, 7]), # Labor Day range
498
+ ]
499
+
500
+ is_holiday = False
501
+ for month, days in major_holidays:
502
+ if pred_date.month == month:
503
+ if isinstance(days, list):
504
+ if pred_date.day in days:
505
+ is_holiday = True
506
+ break
507
+ elif pred_date.day == days:
508
+ is_holiday = True
509
+ break
510
+
511
+ if not is_holiday:
512
+ for month, days in major_holidays:
513
+ if isinstance(days, list):
514
+ holiday_date = datetime.datetime(pred_date.year, month, days[0])
515
+ else:
516
+ holiday_date = datetime.datetime(pred_date.year, month, days)
517
+
518
+ delta = abs((pred_date - holiday_date.date()).days)
519
+ if delta <= 14: # Within 2 weeks
520
+ is_holiday = True
521
+ break
522
+
523
+ sample_data['is_holiday_period'] = is_holiday
524
+
525
+ seasons_by_month = {
526
+ 1: 'Winter', 2: 'Winter', 3: 'Spring',
527
+ 4: 'Spring', 5: 'Spring', 6: 'Summer',
528
+ 7: 'Summer', 8: 'Summer', 9: 'Fall',
529
+ 10: 'Fall', 11: 'Fall', 12: 'Winter'
530
+ }
531
+ sample_data['season'] = seasons_by_month[pred_date.month]
532
+
533
+ if carrier:
534
+ sample_data['carrier'] = carrier
535
+ if carrier_info:
536
+ sample_data['carrier_name'] = carrier_info['name']
537
+ sample_data['carrier_category'] = carrier_info.get('category', 'UNKNOWN')
538
+
539
+ # Random Forest prediction
540
+ try:
541
+ sample_X = pd.DataFrame([sample_data])
542
+
543
+ rf_predicted_fare = rf_model.predict(sample_X)[0]
544
+ print(f"RF prediction: ${rf_predicted_fare:.2f}")
545
+ except Exception as e:
546
+ print(f"Error making Random Forest prediction: {str(e)}")
547
+ rf_predicted_fare = sample_data.get('fare', 180.0)
548
+ print(f"Using fallback fare: ${rf_predicted_fare:.2f}")
549
+
550
+ ts_predicted_fare = None
551
+ combined_prediction = rf_predicted_fare
552
+
553
+ if ts_model is not None:
554
+ try:
555
+
556
+ if granularity == "quarter":
557
+ ts_idx = quarter - 1
558
+ elif granularity == "month":
559
+ ts_idx = pred_date.month - 1
560
+ elif granularity == "week":
561
+ ts_idx = min(week_number - 1, 51) # Max 52 weeks
562
+ else:
563
+
564
+ days_in_year = 366 if calendar.isleap(pred_date.year) else 365
565
+ ts_idx = int((sample_data['day_of_year'] / days_in_year) * 4) # Scale to 0-3
566
+
567
+
568
+ max_steps = 52 if granularity == "week" else 12 if granularity == "month" else 4
569
+ forecasts = ts_model.forecast(steps=max_steps)
570
+ ts_predicted_fare = forecasts[min(ts_idx, len(forecasts)-1)]
571
+ print(f"TS prediction: ${ts_predicted_fare:.2f}")
572
+
573
+
574
+ combined_prediction = 0.7 * rf_predicted_fare + 0.3 * ts_predicted_fare
575
+ print(f"Combined prediction: ${combined_prediction:.2f}")
576
+ except Exception as e:
577
+ print(f"Error making time series prediction: {str(e)}")
578
+
579
+
580
+ holiday_markup = 1.15 if is_holiday else 1.0 # 15% markup for holiday periods
581
+
582
+ # Apply seasonality effect based on month
583
+ seasonal_factors = {
584
+ 1: 1.05, # January (post-holiday)
585
+ 2: 1.0, # February (low season)
586
+ 3: 1.02, # March (spring break)
587
+ 4: 1.05, # April (Easter)
588
+ 5: 1.02, # May
589
+ 6: 1.12, # June (summer peak)
590
+ 7: 1.15, # July (summer peak)
591
+ 8: 1.1, # August (summer end)
592
+ 9: 0.95, # September (low season)
593
+ 10: 0.98, # October (low season)
594
+ 11: 1.1, # November (Thanksgiving)
595
+ 12: 1.2 # December (Christmas)
596
+ }
597
+
598
+ seasonal_factor = seasonal_factors[pred_date.month]
599
+
600
+ # Apply day of week effect
601
+ dow_factors = {
602
+ 0: 0.98, # Monday
603
+ 1: 0.97, # Tuesday
604
+ 2: 0.97, # Wednesday
605
+ 3: 1.02, # Thursday
606
+ 4: 1.05, # Friday
607
+ 5: 1.02, # Saturday
608
+ 6: 0.99 # Sunday
609
+ }
610
+
611
+ dow_factor = dow_factors[pred_date.weekday()]
612
+
613
+ final_prediction = combined_prediction * holiday_markup * seasonal_factor * dow_factor
614
+ print(f"Final prediction with all effects: ${final_prediction:.2f}")
615
+
616
+ if carrier and carrier_info:
617
+ carrier_adjusted_prediction = adjust_fare_by_carrier(final_prediction, carrier, route_name)
618
+
619
+ carrier_effect_ratio = carrier_adjusted_prediction / final_prediction
620
+ if abs(carrier_effect_ratio - 1.0) > 0.05: # If more than 5% difference
621
+ print(f"Applying carrier effect: ${final_prediction:.2f} → ${carrier_adjusted_prediction:.2f}")
622
+ final_prediction = carrier_adjusted_prediction
623
+
624
+ prediction_entry = {
625
+ 'date': pred_date,
626
+ 'predicted_fare': float(final_prediction),
627
+ 'rf_prediction': float(rf_predicted_fare) if rf_predicted_fare is not None else None,
628
+ 'ts_prediction': float(ts_predicted_fare) if ts_predicted_fare is not None else None,
629
+ 'is_holiday_period': bool(is_holiday),
630
+ 'year': int(pred_date.year),
631
+ 'month': int(pred_date.month),
632
+ 'month_name': pred_date.strftime('%B'),
633
+ 'quarter': int(quarter),
634
+ 'week': int(week_number),
635
+ 'day_of_week': pred_date.strftime('%A'),
636
+ 'carrier': carrier if carrier else None,
637
+ 'carrier_name': carrier_info['name'] if carrier_info else None
638
+ }
639
+
640
+ predictions.append(prediction_entry)
641
+
642
+ predictions_df = pd.DataFrame(predictions)
643
+
644
+ filtered_predictions_df = predictions_df.copy()
645
+ travel_period_filtered = False
646
+
647
+ if start_month is not None and end_month is not None and granularity in ["month", "quarter"]:
648
+ travel_period_filtered = True
649
+
650
+ def is_in_travel_period(month):
651
+ if start_month <= end_month:
652
+ return start_month <= month <= end_month
653
+ else: # Wrap around case (e.g., November to February)
654
+ return month >= start_month or month <= end_month
655
+
656
+ # Filter predictions by travel period
657
+ if granularity == "month":
658
+ filtered_predictions_df = predictions_df[predictions_df['month'].apply(is_in_travel_period)]
659
+ elif granularity == "quarter":
660
+ # Map months to quarters
661
+ start_quarter = (start_month - 1) // 3 + 1
662
+ end_quarter = (end_month - 1) // 3 + 1
663
+
664
+ def is_quarter_in_travel_period(quarter):
665
+ if start_quarter <= end_quarter:
666
+ return start_quarter <= quarter <= end_quarter
667
+ else: # Wrap around case
668
+ return quarter >= start_quarter or quarter <= end_quarter
669
+
670
+ filtered_predictions_df = predictions_df[predictions_df['quarter'].apply(is_quarter_in_travel_period)]
671
+
672
+ if travel_period_filtered and not filtered_predictions_df.empty:
673
+ active_df = filtered_predictions_df
674
+ else:
675
+ active_df = predictions_df
676
+
677
+ # Find the best time at the specified granularity
678
+ if granularity == "date":
679
+ best_idx = active_df['predicted_fare'].idxmin()
680
+ best_time_row = active_df.loc[best_idx]
681
+ best_time = {k: convert_numpy_types(v) for k, v in best_time_row.to_dict().items()}
682
+
683
+ if isinstance(best_time['date'], str):
684
+ date_obj = datetime.datetime.strptime(best_time['date'], '%Y-%m-%d').date()
685
+ formatted_best = f"{date_obj.strftime('%A, %B %d, %Y')}"
686
+ else:
687
+ formatted_best = f"{best_time['date'].strftime('%A, %B %d, %Y')}"
688
+
689
+ elif granularity == "week":
690
+ # Group by week
691
+ weekly_avg = active_df.groupby('week')['predicted_fare'].mean().reset_index()
692
+ best_week = int(weekly_avg.loc[weekly_avg['predicted_fare'].idxmin()]['week'])
693
+ best_week_data = active_df[active_df['week'] == best_week].iloc[0]
694
+ best_time = {k: convert_numpy_types(v) for k, v in best_week_data.to_dict().items()}
695
+
696
+ best_date = best_time['date']
697
+ if isinstance(best_date, str):
698
+ best_date = datetime.datetime.strptime(best_date, '%Y-%m-%d').date()
699
+ start_of_week = best_date - datetime.timedelta(days=best_date.weekday())
700
+ end_of_week = start_of_week + datetime.timedelta(days=6)
701
+ formatted_best = f"Week {best_week} ({start_of_week.strftime('%b %d')} - {end_of_week.strftime('%b %d')})"
702
+
703
+ elif granularity == "month":
704
+ monthly_avg = active_df.groupby(['month', 'month_name'])['predicted_fare'].mean().reset_index()
705
+ best_month_idx = monthly_avg['predicted_fare'].idxmin()
706
+ best_month = int(monthly_avg.loc[best_month_idx]['month'])
707
+ best_month_name = monthly_avg.loc[best_month_idx]['month_name']
708
+ best_month_fare = float(monthly_avg.loc[best_month_idx]['predicted_fare'])
709
+
710
+ best_time = {
711
+ 'month': best_month,
712
+ 'month_name': best_month_name,
713
+ 'predicted_fare': best_month_fare,
714
+ 'carrier': carrier,
715
+ 'carrier_name': carrier_info['name'] if carrier_info else None
716
+ }
717
+ formatted_best = f"{best_month_name}"
718
+
719
+ elif granularity == "quarter":
720
+ quarterly_avg = active_df.groupby('quarter')['predicted_fare'].mean().reset_index()
721
+ best_quarter_idx = quarterly_avg['predicted_fare'].idxmin()
722
+ best_quarter = int(quarterly_avg.loc[best_quarter_idx]['quarter'])
723
+ best_quarter_fare = float(quarterly_avg.loc[best_quarter_idx]['predicted_fare'])
724
+
725
+ best_time = {
726
+ 'quarter': best_quarter,
727
+ 'predicted_fare': best_quarter_fare,
728
+ 'carrier': carrier,
729
+ 'carrier_name': carrier_info['name'] if carrier_info else None
730
+ }
731
+ formatted_best = f"Q{best_quarter}"
732
+
733
+ viz_df = filtered_predictions_df if travel_period_filtered and not filtered_predictions_df.empty else predictions_df
734
+
735
+ chart_data = {}
736
+
737
+ if granularity == "date":
738
+ chart_data = {
739
+ 'type': 'line',
740
+ 'data': [
741
+ {
742
+ 'date': pred['date'].isoformat() if not isinstance(pred['date'], str) else pred['date'],
743
+ 'fare': round(pred['predicted_fare'], 2),
744
+ 'isHoliday': pred['is_holiday_period'],
745
+ 'isBest': False
746
+ }
747
+ for pred in viz_df.to_dict('records')
748
+ ],
749
+ 'xAxisKey': 'date',
750
+ 'yAxisKey': 'fare',
751
+ 'xAxisLabel': 'Date',
752
+ 'yAxisLabel': 'Predicted Fare ($)',
753
+ 'title': f'Predicted Fares for {route_name}'
754
+ }
755
+
756
+ if carrier and carrier_info:
757
+ chart_data['title'] += f' with {carrier} ({carrier_info["name"]})'
758
+
759
+ if travel_period_filtered:
760
+ chart_data['title'] += f' (Travel Period: {months[start_month-1]} to {months[end_month-1]})'
761
+
762
+ best_idx = viz_df['predicted_fare'].idxmin()
763
+ best_date = viz_df.loc[best_idx, 'date']
764
+ best_date_str = best_date.isoformat() if not isinstance(best_date, str) else best_date
765
+
766
+ for point in chart_data['data']:
767
+ if point['date'] == best_date_str:
768
+ point['isBest'] = True
769
+ break
770
+
771
+ elif granularity in ["week", "month", "quarter"]:
772
+ if granularity == "week":
773
+ # Group by week
774
+ grouped_data = viz_df.groupby('week')['predicted_fare'].mean().reset_index()
775
+ label_key = 'week'
776
+ label_formatter = lambda x: f"Week {int(x)}"
777
+
778
+ elif granularity == "month":
779
+ # Group by month
780
+ grouped_data = viz_df.groupby(['month', 'month_name'])['predicted_fare'].mean().reset_index()
781
+ grouped_data = grouped_data.sort_values('month')
782
+ label_key = 'month_name'
783
+ label_formatter = lambda x: x
784
+
785
+ else: # Quarter
786
+ # Group by quarter
787
+ grouped_data = viz_df.groupby('quarter')['predicted_fare'].mean().reset_index()
788
+ label_key = 'quarter'
789
+ label_formatter = lambda x: f"Q{int(x)}"
790
+
791
+ # Find best time period
792
+ best_idx = grouped_data['predicted_fare'].idxmin()
793
+ best_value = grouped_data.loc[best_idx, label_key]
794
+
795
+ chart_data = {
796
+ 'type': 'bar',
797
+ 'data': [
798
+ {
799
+ 'label': label_formatter(row[label_key]),
800
+ 'value': label_key,
801
+ 'originalValue': row[label_key],
802
+ 'fare': round(row['predicted_fare'], 2),
803
+ 'isBest': row[label_key] == best_value
804
+ }
805
+ for _, row in grouped_data.iterrows()
806
+ ],
807
+ 'xAxisKey': 'label',
808
+ 'yAxisKey': 'fare',
809
+ 'xAxisLabel': 'Time Period',
810
+ 'yAxisLabel': 'Predicted Fare ($)',
811
+ 'title': f'Predicted Fares for {route_name}'
812
+ }
813
+
814
+ if carrier and carrier_info:
815
+ chart_data['title'] += f' with {carrier} ({carrier_info["name"]})'
816
+
817
+ if travel_period_filtered:
818
+ chart_data['title'] += f' (Travel Period: {months[start_month-1]} to {months[end_month-1]})'
819
+
820
+ full_analysis_chart_data = {}
821
+
822
+ if travel_period_filtered and (granularity == "month" or granularity == "quarter"):
823
+ if granularity == "month":
824
+ full_grouped_data = predictions_df.groupby(['month', 'month_name'])['predicted_fare'].mean().reset_index()
825
+ full_grouped_data = full_grouped_data.sort_values('month')
826
+ label_key = 'month_name'
827
+ label_formatter = lambda x: x
828
+
829
+ def is_in_travel_period(month):
830
+ if start_month <= end_month:
831
+ return start_month <= month <= end_month
832
+ else:
833
+ return month >= start_month or month <= end_month
834
+
835
+ best_month = int(monthly_avg.loc[monthly_avg['predicted_fare'].idxmin()]['month'])
836
+
837
+ else: # quarter
838
+ full_grouped_data = predictions_df.groupby('quarter')['predicted_fare'].mean().reset_index()
839
+ label_key = 'quarter'
840
+ label_formatter = lambda x: f"Q{int(x)}"
841
+
842
+ def is_quarter_in_travel_period(quarter):
843
+ start_quarter = (start_month - 1) // 3 + 1
844
+ end_quarter = (end_month - 1) // 3 + 1
845
+ if start_quarter <= end_quarter:
846
+ return start_quarter <= quarter <= end_quarter
847
+ else:
848
+ return quarter >= start_quarter or quarter <= end_quarter
849
+
850
+ best_quarter = int(quarterly_avg.loc[quarterly_avg['predicted_fare'].idxmin()]['quarter'])
851
+
852
+ full_analysis_chart_data = {
853
+ 'type': 'bar',
854
+ 'data': [],
855
+ 'xAxisKey': 'label',
856
+ 'yAxisKey': 'fare',
857
+ 'xAxisLabel': 'Time Period',
858
+ 'yAxisLabel': 'Predicted Fare ($)',
859
+ 'title': f'Full Year Price Analysis for {route_name}'
860
+ }
861
+
862
+ if carrier and carrier_info:
863
+ full_analysis_chart_data['title'] += f' with {carrier} ({carrier_info["name"]})'
864
+
865
+ full_analysis_chart_data['title'] += f' (Travel Period: {months[start_month-1]} to {months[end_month-1]})'
866
+
867
+ for _, row in full_grouped_data.iterrows():
868
+ original_value = row[label_key]
869
+ if granularity == "month":
870
+ in_travel_period = is_in_travel_period(row['month'])
871
+ is_best = row['month'] == best_month
872
+ else: # quarter
873
+ in_travel_period = is_quarter_in_travel_period(row['quarter'])
874
+ is_best = row['quarter'] == best_quarter
875
+
876
+ full_analysis_chart_data['data'].append({
877
+ 'label': label_formatter(original_value),
878
+ 'value': original_value,
879
+ 'fare': round(row['predicted_fare'], 2),
880
+ 'inTravelPeriod': in_travel_period,
881
+ 'isBest': is_best
882
+ })
883
+ else:
884
+ full_analysis_chart_data = chart_data
885
+
886
+ filtered_predictions = []
887
+ for pred in active_df.to_dict('records'):
888
+ filtered_predictions.append({k: convert_numpy_types(v) for k, v in pred.items()})
889
+
890
+ all_predictions = []
891
+ for pred in predictions_df.to_dict('records'):
892
+ all_predictions.append({k: convert_numpy_types(v) for k, v in pred.items()})
893
+
894
+ result = {
895
+ 'route': route_name,
896
+ 'granularity': granularity,
897
+ 'carrier': carrier,
898
+ 'carrier_name': carrier_info['name'] if carrier_info else None,
899
+ 'best_time': best_time,
900
+ 'formatted_best_time': formatted_best,
901
+ 'filtered_predictions': filtered_predictions,
902
+ 'all_predictions': all_predictions,
903
+ 'chart_data': chart_data,
904
+ 'full_analysis_chart_data': full_analysis_chart_data,
905
+ 'travel_period': {
906
+ 'start_month': start_month,
907
+ 'end_month': end_month,
908
+ 'start_month_name': months[start_month-1] if start_month is not None else None,
909
+ 'end_month_name': months[end_month-1] if end_month is not None else None
910
+ } if start_month and end_month else None,
911
+ 'travel_period_filtered': travel_period_filtered,
912
+ 'success': True
913
+ }
914
+
915
+ result = json.loads(json.dumps(result, default=lambda o: convert_numpy_types(o)))
916
+
917
+ return result
918
+
919
+ except Exception as e:
920
+ import traceback
921
+ print(f"Error in prediction: {e}")
922
+ print(traceback.format_exc())
923
+ return {
924
+ 'error': str(e),
925
+ 'success': False
926
+ }
927
+
928
+ if __name__ == "__main__":
929
+ print("\n=== PREDICTING BY QUARTER ===")
930
+ result_quarter = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="quarter")
931
+
932
+ print("\n=== PREDICTING BY MONTH ===")
933
+ result_month = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="month")
934
+
935
+ print("\n=== PREDICTING BY MONTH WITH TRAVEL PERIOD ===")
936
+ result_month_filtered = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="month", start_month=4, end_month=8)
937
+
938
+ print("\n=== PREDICTING BY WEEK ===")
939
+ result_week = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="week")
940
+
941
+ print("\n=== PREDICTING BY DATE ===")
942
+ result_date = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="date", weeks_ahead=8)
943
+
944
+ print("\n=== PREDICTING WITH SPECIFIC CARRIER ===")
945
+ result_carrier = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="month", carrier="WN")
946
+
947
+ carriers_to_test = ['AA', 'DL', 'WN', 'F9', 'NK', 'G4']
948
+ for test_carrier in carriers_to_test:
949
+ print(f"\n=== TESTING WITH CARRIER: {test_carrier} ===")
950
+ result = predict_best_time_to_buy_ticket('ABQ', 'PHX', granularity="month", carrier=test_carrier)
951
+ if result.get('success', False):
952
+ print(f"Carrier: {test_carrier}")
953
+ print(f"Predicted fare: ${result['best_time']['predicted_fare']:.2f}")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pandas
4
+ numpy
5
+ joblib
6
+ matplotlib
7
+ python-dateutil
8
+ scikit-learn
9
+ pydantic
10
+ huggingface_hub
11
+ python-dotenv
space.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ title: FlightSavvy API
2
+ emoji: ✈️
3
+ colorFrom: blue
4
+ colorTo: purple
5
+ sdk: docker
6
+ pinned: false