# Source: Hugging Face repository "flight-fare-hf" by Arsive2, file app.py
# (commit 689c872, "updated app.py").
import os
import traceback
from typing import Optional, Union
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
# Import the prediction function
from prediction import predict_best_time_to_buy_ticket
# Create the FastAPI application; title/description/version populate the
# generated OpenAPI schema and the interactive docs page.
app = FastAPI(
    title="FlightSavvy API",
    description="API for predicting the best time to buy flight tickets",
    version="1.0.0",
)
# Register CORS so browser clients on any origin can call the API.
# NOTE(review): the FastAPI CORS docs state that allow_origins cannot be ["*"]
# when credentials are allowed (explicit origins are required). None of the
# endpoints in this file read cookies or auth headers, so allow_credentials
# may be unnecessary — confirm with the frontend before tightening.
_CORS_SETTINGS = {
    "allow_origins": ["*"],  # any origin
    "allow_methods": ["*"],  # any HTTP method
    "allow_headers": ["*"],  # any request header
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Request body schema for POST /api/predict, validated by Pydantic.
class PredictionRequest(BaseModel):
    """JSON body accepted by ``POST /api/predict``.

    Field names are the JSON keys clients send (the mix of camelCase and
    snake_case is part of the public contract). Only ``origin`` and
    ``destination`` are required; everything else may be omitted.
    """

    # Route endpoints (documented by /help as airport codes).
    origin: str
    destination: str
    # Defaults to "quarter" when omitted; an explicit JSON null yields None.
    granularity: Union[str, None] = "quarter"
    futureYear: Union[int, None] = None
    weeksAhead: Union[int, None] = None
    # A month may be sent as a number or a month name; the predict endpoint
    # converts recognised names to their month number.
    start_month: Union[int, str, None] = None
    end_month: Union[int, str, None] = None
    carrier: Union[str, None] = None
# Fetch model files from the Hugging Face Hub when the app starts.
@app.on_event("startup")
async def download_models():
    """Download any model file that is missing from the working directory.

    Each file is checked with ``os.path.exists`` relative to the current
    directory and, if absent, downloaded into ``"."`` from the Hub repo
    below. A failed download is printed and skipped rather than aborting
    startup; per the original comment, prediction.py has fallbacks for
    missing models (NOTE(review): confirm).
    """
    source_repo = "Arsive/flight-fare-prediction"
    for filename in ["flight_fare_rf_model.joblib", "flight_fare_ts_model.joblib"]:
        if os.path.exists(filename):
            continue  # already present locally; no network call needed
        try:
            print(f"Downloading {filename} from Hugging Face...")
            hf_hub_download(repo_id=source_repo, filename=filename, local_dir=".")
            print(f"Downloaded {filename} successfully")
        except Exception as exc:
            # Best-effort: report and move on to the next file.
            print(f"Error downloading {filename}: {exc}")
# Month spellings accepted for start_month / end_month, in calendar order so
# that position + 1 is the month number. Full names and three-letter
# abbreviations are matched case-insensitively ("March", "march", "Mar" -> 3).
_MONTH_NAMES = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
)
_MONTH_NUMBERS = {
    spelling: number
    for number, name in enumerate(_MONTH_NAMES, start=1)
    for spelling in (name, name[:3])
}


def _normalize_month(value: Optional[Union[int, str]]) -> Optional[Union[int, str]]:
    """Convert a month given as a name or numeric string to its number.

    Integers and ``None`` are returned unchanged (no 1-12 range check is
    done here). Strings are stripped first; decimal strings such as "3"
    become ints (depending on the Pydantic version, a numeric string can
    reach this point as ``str`` even though the field also accepts int).
    Month names and three-letter abbreviations map to 1-12. Any other
    string is returned unchanged, preserving the original best-effort
    pass-through to the prediction function.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if text.isdecimal():
        return int(text)
    return _MONTH_NUMBERS.get(text.lower(), value)


@app.post("/api/predict")
def predict(request: PredictionRequest):
    """Run the fare-timing prediction for the route/options in *request*.

    Declared with plain ``def`` rather than ``async def``: the prediction
    function is a synchronous (non-awaited) call, and FastAPI runs sync
    endpoints in its threadpool, so a slow prediction no longer blocks the
    event loop for every other request.
    NOTE(review): predictions can now run concurrently in worker threads;
    confirm prediction.py holds no non-thread-safe shared state.

    Returns:
        Whatever ``predict_best_time_to_buy_ticket`` returns.

    Raises:
        HTTPException: status 500 with the error text if anything fails.
    """
    try:
        result = predict_best_time_to_buy_ticket(
            origin=request.origin,
            destination=request.destination,
            granularity=request.granularity,
            future_year=request.futureYear,
            weeks_ahead=request.weeksAhead,
            # Accept month names / numeric strings as well as ints.
            start_month=_normalize_month(request.start_month),
            end_month=_normalize_month(request.end_month),
            carrier=request.carrier,
        )
    except Exception as e:
        # Log the error with its traceback, then surface it as a 500.
        print(f"API Error: {str(e)}")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
    return result
@app.get("/")
async def root():
    """Landing endpoint: a greeting that points clients to /api/predict."""
    greeting = "Welcome to the FlightSavvy API! Use /api/predict for predictions."
    return {"message": greeting}
@app.get("/help")
async def help():  # shadows builtin help(); name kept since FastAPI derives the route name from it
    """Describe the prediction endpoint and its request parameters."""
    parameter_docs = {
        "origin": "Departure airport code",
        "destination": "Arrival airport code",
        "granularity": "Time granularity (e.g., 'quarter')",
        "futureYear": "Year for prediction",
        "weeksAhead": "Weeks ahead for prediction",
        "start_month": "Start month (1-12 or month name)",
        "end_month": "End month (1-12 or month name)",
        "carrier": "Carrier code (optional)",
    }
    predict_doc = {
        "description": "Predict the best time to buy flight tickets.",
        "parameters": parameter_docs,
    }
    return {
        "message": "This API predicts the best time to buy flight tickets.",
        "usage": {"POST /api/predict": predict_doc},
    }
# Script entry point: serve the app with uvicorn when executed directly.
if __name__ == "__main__":
    # "app:app" is an import string (module ``app``, attribute ``app``).
    # Port 7860 is presumably the Hugging Face Spaces default — confirm.
    _bind = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **_bind)