Spaces:
Sleeping
Sleeping
| import os | |
| import traceback | |
| from typing import Optional, Union | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import hf_hub_download | |
| from pydantic import BaseModel | |
| # Import the prediction function | |
| from prediction import predict_best_time_to_buy_ticket | |
# Application object shared by every route and middleware in this module.
app = FastAPI(
    title="FlightSavvy API",
    description="API for predicting the best time to buy flight tickets",
    version="1.0.0",
)

# Fully permissive CORS policy: any origin, any method, any header.
# NOTE(review): a wildcard origin together with allow_credentials=True means
# any website may issue credentialed cross-origin requests (Starlette echoes
# the caller's Origin in that case -- confirm against the Starlette docs).
# Nothing in this file uses cookies or auth, so allow_credentials=False may
# be sufficient; confirm with the frontend before changing it.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
class PredictionRequest(BaseModel):
    """Body of a prediction request, validated by Pydantic.

    Only ``origin`` and ``destination`` are required; every other field may be
    omitted. ``start_month`` and ``end_month`` accept either a number or a
    month name (the endpoint normalises names to numbers).
    """

    # Route endpoints (described as airport codes in the API help text).
    origin: str
    destination: str
    # Time granularity of the prediction; "quarter" unless specified.
    granularity: Union[str, None] = "quarter"
    # camelCase names are kept as-is: they are the JSON keys clients send.
    futureYear: Union[int, None] = None
    weeksAhead: Union[int, None] = None
    # Month window: 1-12 or a month name such as "January".
    start_month: Union[int, str, None] = None
    end_month: Union[int, str, None] = None
    # Optional carrier code to filter on.
    carrier: Union[str, None] = None
# Download the model files on application startup if they don't exist yet.
@app.on_event("startup")
async def download_models():
    """Fetch any missing model files from the Hugging Face Hub.

    Each model file is looked up relative to the current working directory
    and, if absent, downloaded from the ``Arsive/flight-fare-prediction``
    repository into that same directory. Download errors are printed and then
    ignored so the application can still start.
    """
    models = ["flight_fare_rf_model.joblib", "flight_fare_ts_model.joblib"]
    repo_id = "Arsive/flight-fare-prediction"
    for model in models:
        if os.path.exists(model):
            continue
        try:
            print(f"Downloading {model} from Hugging Face...")
            # local_dir="." matches the relative existence check above.
            hf_hub_download(repo_id=repo_id, filename=model, local_dir=".")
            print(f"Downloaded {model} successfully")
        except Exception as e:
            print(f"Error downloading {model}: {e}")
            # Continue even if download fails - prediction.py has fallbacks
            # (per the original author; not verifiable from this file).
# English month names in calendar order (index 0 is January). Hard-coded
# rather than taken from calendar.month_name, which depends on the locale.
_MONTH_NAMES = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
)


def _parse_month(value: Optional[Union[int, str]], field: str) -> Optional[int]:
    """Normalise a month given as 1-12, a decimal string or an English name.

    Names match case-insensitively, in full or by their first three letters
    (e.g. "January", "jan", "SEP").

    Args:
        value: Raw month value from the request (int, str or None).
        field: Name of the request field, used in error messages.

    Returns:
        The month number in 1-12, or None when ``value`` is None.

    Raises:
        HTTPException: 400 when ``value`` is not a recognisable month.
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value.strip()
        # isdecimal (not isdigit) so int() below cannot fail on characters
        # such as superscript digits, which isdigit() accepts.
        if not text.isdecimal():
            key = text.casefold()
            for number, name in enumerate(_MONTH_NAMES, start=1):
                if key in (name, name[:3]):
                    return number
            raise HTTPException(
                status_code=400,
                detail=f"Invalid {field}: {value!r} is not a month name or number",
            )
        month = int(text)
    else:
        month = int(value)
    if not 1 <= month <= 12:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {field}: {value!r} must be between 1 and 12",
        )
    return month


@app.post("/api/predict")
async def predict(request: PredictionRequest):
    """Predict the best time to buy a ticket for the requested route.

    Month fields are normalised to integers 1-12 before the request is
    forwarded to ``predict_best_time_to_buy_ticket``; that function's result
    is returned unchanged.

    Raises:
        HTTPException: 400 for an unrecognisable ``start_month``/``end_month``;
            500 (detail = the error message) for any other failure.
    """
    try:
        start_month = _parse_month(request.start_month, "start_month")
        end_month = _parse_month(request.end_month, "end_month")
        # NOTE(review): this synchronous call runs on the event loop, so other
        # requests wait while a prediction is computed. Consider
        # fastapi.concurrency.run_in_threadpool once prediction.py is
        # confirmed to be thread-safe.
        return predict_best_time_to_buy_ticket(
            origin=request.origin,
            destination=request.destination,
            granularity=request.granularity,
            future_year=request.futureYear,
            weeks_ahead=request.weeksAhead,
            start_month=start_month,
            end_month=end_month,
            carrier=request.carrier,
        )
    except HTTPException:
        # Errors that already carry a status code (e.g. the 400s above) must
        # not be turned into a 500 by the generic handler below.
        raise
    except Exception as e:
        # Log the full traceback server-side, then report a 500 to the client.
        print(f"API Error: {str(e)}")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
async def root():
    """Return a welcome message pointing clients at the prediction endpoint."""
    return {"message": "Welcome to the FlightSavvy API! Use /api/predict for predictions."}
async def help():
    """Describe the API and the parameters accepted by ``POST /api/predict``.

    NOTE(review): in the code shown this coroutine is not registered with
    ``app``, so it is not reachable over HTTP; presumably it was meant to be
    a GET route -- confirm the intended path before wiring it up. (The name
    also shadows the ``help`` builtin within this module.)
    """
    predict_parameters = {
        "origin": "Departure airport code",
        "destination": "Arrival airport code",
        "granularity": "Time granularity (e.g., 'quarter')",
        "futureYear": "Year for prediction",
        "weeksAhead": "Weeks ahead for prediction",
        "start_month": "Start month (1-12 or month name)",
        "end_month": "End month (1-12 or month name)",
        "carrier": "Carrier code (optional)",
    }
    predict_endpoint = {
        "description": "Predict the best time to buy flight tickets.",
        "parameters": predict_parameters,
    }
    usage = {"POST /api/predict": predict_endpoint}
    return {
        "message": "This API predicts the best time to buy flight tickets.",
        "usage": usage,
    }
# If running directly, start the server.
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: the string
    # hard-codes this file's module name and makes uvicorn import the file a
    # second time (as "app") alongside the running "__main__" copy.
    # PORT may override the default 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))