# Spaces: Sleeping
# app.py
import logging
import os
from contextlib import asynccontextmanager
from glob import glob
from typing import Any, Dict

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator

from src.entity.config_entity import DataTransformationConfig, TrainingPipelineConfig
# Configure logging
# basicConfig sets up the root logger once at import time; records are
# emitted at INFO and above with timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables for model artifacts
# The startup handler stores the loaded objects under the 'pipeline' and
# 'model' keys and empties the dict again on shutdown.
model_artifacts: Dict[str, Any] = {}
def find_latest_artifacts(artifacts_base="Artifacts"):
    """Locate the preprocessing pipeline file of the most recent artifacts run.

    Sub-directories of ``artifacts_base`` are ordered by name, descending, and
    only the first one is inspected, so run directory names are assumed to
    sort chronologically (e.g. timestamp-based names).

    Inside that run the file is searched for in
    ``data_transformation/transformed_object``: a set of well-known file names
    is tried first, then any ``*.pkl`` file (alphabetically first, so the
    choice is deterministic).

    Args:
        artifacts_base: Directory containing one sub-directory per training
            run. Defaults to ``"Artifacts"`` relative to the working directory.

    Returns:
        A ``(pipeline_path, None)`` tuple. ``pipeline_path`` is ``None`` when
        nothing was found or the directory tree could not be read. The second
        element is always ``None``; it is kept so callers that unpack two
        values keep working.
    """
    # Same logger object as the module-level ``logger`` (loggers are cached by name).
    log = logging.getLogger(__name__)
    known_names = ("preprocessing.pkl", "preprocessor.pkl", "pipeline.pkl", "transformed_object.pkl")

    try:
        if not os.path.isdir(artifacts_base):
            return None, None

        # Most recent run first (relies on sortable directory names).
        run_dirs = sorted(
            (
                entry
                for entry in os.listdir(artifacts_base)
                if os.path.isdir(os.path.join(artifacts_base, entry))
            ),
            reverse=True,
        )
        if not run_dirs:
            return None, None

        latest_dir = run_dirs[0]
        log.info(f"Found latest artifacts directory: {latest_dir}")

        pipeline_dir = os.path.join(
            artifacts_base, latest_dir, "data_transformation", "transformed_object"
        )
        if not os.path.isdir(pipeline_dir):
            return None, None

        # Prefer the conventional file names, in priority order.
        for name in known_names:
            pipeline_path = os.path.join(pipeline_dir, name)
            if os.path.isfile(pipeline_path):
                log.info(f"Found pipeline at: {pipeline_path}")
                return pipeline_path, None

        # Otherwise accept any pickle; glob order is arbitrary, so sort it
        # to make the pick reproducible across runs and platforms.
        pkl_files = sorted(glob(os.path.join(pipeline_dir, "*.pkl")))
        if pkl_files:
            log.info(f"Found pipeline file: {pkl_files[0]}")
            return pkl_files[0], None

        return None, None
    except OSError as e:
        # Best effort: an unreadable directory is reported, not raised.
        log.error(f"Error finding artifacts: {e}")
        return None, None
# Resolve PIPELINE_PATH and MODEL_PATH at import time.
# Try to use config first, fall back to finding latest artifacts
try:
    training_pipeline_config = TrainingPipelineConfig()
    data_transformation_config = DataTransformationConfig(training_pipeline_config)
    PIPELINE_PATH = data_transformation_config.transformed_object_file_path
    MODEL_PATH = "final_model/model.pkl"
    # Check if the config path exists
    if not os.path.exists(PIPELINE_PATH):
        logger.warning(f"Config path not found: {PIPELINE_PATH}")
        logger.info("Attempting to find latest artifacts...")
        found_pipeline, _ = find_latest_artifacts()
        # When the search finds nothing, PIPELINE_PATH keeps the missing
        # config path; the startup handler checks existence again before
        # loading and reports the problem there.
        if found_pipeline:
            PIPELINE_PATH = found_pipeline
            logger.info(f"Using found pipeline: {PIPELINE_PATH}")
except Exception as e:
    # Deliberately broad: any failure while building the config objects (or
    # in the checks above) falls back to scanning the artifacts directory.
    logger.warning(f"Could not load config: {e}")
    logger.info("Attempting to find latest artifacts...")
    # PIPELINE_PATH may be None here if the search finds nothing.
    PIPELINE_PATH, _ = find_latest_artifacts()
    MODEL_PATH = "final_model/model.pkl"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model artifacts on startup and release them on shutdown.

    On startup the preprocessing pipeline (``PIPELINE_PATH``) and the trained
    model (``MODEL_PATH``) are loaded with joblib and stored in
    ``model_artifacts`` under the ``'pipeline'`` and ``'model'`` keys.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        FileNotFoundError: If the pipeline or model file does not exist.
        Exception: Any error raised while loading is logged and re-raised, so
            the application does not start with missing artifacts.
    """
    try:
        logger.info("Loading model artifacts...")
        logger.info(f"Pipeline path: {PIPELINE_PATH}")
        logger.info(f"Model path: {MODEL_PATH}")

        if not PIPELINE_PATH or not os.path.exists(PIPELINE_PATH):
            message_parts = [
                f"Pipeline not found at {PIPELINE_PATH}\n",
                "\nPlease ensure you have run the training pipeline first.\n",
                "Training creates the pipeline at: Artifacts/<timestamp>/data_transformation/transformed_object/\n",
            ]
            # List what files exist next to the expected location to help
            # diagnose a wrong file name.
            if PIPELINE_PATH:
                parent_dir = os.path.dirname(PIPELINE_PATH)
                if os.path.isdir(parent_dir):
                    message_parts.append(f"\nFiles found in {parent_dir}:\n")
                    message_parts.extend(f" - {f}\n" for f in os.listdir(parent_dir))
            raise FileNotFoundError("".join(message_parts))

        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model not found at {MODEL_PATH}\n"
                                    "Please ensure the model has been trained and saved to final_model/model.pkl")

        # NOTE: joblib.load unpickles arbitrary objects; both files must come
        # from a trusted training run.
        # Load the complete preprocessing pipeline
        logger.info(f"Loading pipeline from: {PIPELINE_PATH}")
        model_artifacts['pipeline'] = joblib.load(PIPELINE_PATH)

        # Load the trained model
        logger.info(f"Loading model from: {MODEL_PATH}")
        model_artifacts['model'] = joblib.load(MODEL_PATH)

        logger.info("✓ Model artifacts loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model artifacts: {e}")
        raise

    try:
        yield
    finally:
        # Cleanup runs even if the application exits through an exception.
        logger.info("Shutting down application...")
        model_artifacts.clear()
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="ETA Prediction API",
    version="1.0.0",
    description="API for predicting delivery time based on various features",
    lifespan=lifespan
)

# Add CORS middleware
# The API is open to every origin. Credentialed requests are turned off
# because wildcard origins/methods/headers must not be combined with
# allow_credentials=True; to support cookies or auth headers, list the
# allowed origins, methods and headers explicitly and re-enable it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Input schema with ORIGINAL feature names
class DeliveryInput(BaseModel):
    """Request body describing one delivery to predict.

    Field names match the raw feature names, because the request is turned
    into a single-row DataFrame and passed to the preprocessing pipeline.
    Numeric fields are range-checked and categorical fields are restricted
    to a fixed set of values; violations raise ``ValueError`` during
    validation.
    """

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "Distance_km": 10.5,
                "Courier_Experience_yrs": 2.0,
                "Vehicle_Type": "Scooter",
                "Weather": "Sunny",
                "Time_of_Day": "Morning",
                "Traffic_Level": "Low"
            }
        }
    )

    Distance_km: float = Field(..., gt=0, examples=[10.5], description="Distance in kilometers")
    Courier_Experience_yrs: float = Field(..., ge=0, examples=[2.0], description="Courier experience in years")
    Vehicle_Type: str = Field(..., examples=["Scooter"], description="Vehicle type")
    Weather: str = Field(..., examples=["Sunny"], description="Weather condition")
    Time_of_Day: str = Field(..., examples=["Morning"], description="Time of day")
    Traffic_Level: str = Field(..., examples=["Low"], description="Traffic level")

    @field_validator('Distance_km')
    @classmethod
    def validate_distance(cls, v):
        """Reject distances above 1000 km (the lower bound is set on the field)."""
        if v > 1000:
            raise ValueError('Distance cannot exceed 1000 km')
        return v

    @field_validator('Courier_Experience_yrs')
    @classmethod
    def validate_experience(cls, v):
        """Reject experience values above 50 years."""
        if v > 50:
            raise ValueError('Experience cannot exceed 50 years')
        return v

    @field_validator('Vehicle_Type')
    @classmethod
    def validate_vehicle(cls, v):
        """Accept only the listed vehicle types."""
        allowed = ['Scooter', 'Pickup Truck', 'Motorcycle']
        if v not in allowed:
            raise ValueError(f'Vehicle type must be one of {allowed}')
        return v

    @field_validator('Weather')
    @classmethod
    def validate_weather(cls, v):
        """Accept only the listed weather conditions."""
        allowed = ['Sunny', 'Rainy', 'Foggy', 'Snowy', 'Windy']
        if v not in allowed:
            raise ValueError(f'Weather must be one of {allowed}')
        return v

    @field_validator('Time_of_Day')
    @classmethod
    def validate_time(cls, v):
        """Accept only the listed times of day."""
        allowed = ['Morning', 'Afternoon', 'Evening', 'Night']
        if v not in allowed:
            raise ValueError(f'Time of day must be one of {allowed}')
        return v

    @field_validator('Traffic_Level')
    @classmethod
    def validate_traffic(cls, v):
        """Accept only the listed traffic levels."""
        allowed = ['Low', 'Medium', 'High']
        if v not in allowed:
            raise ValueError(f'Traffic level must be one of {allowed}')
        return v
class PredictionResponse(BaseModel):
    """Response body for a single prediction."""

    # ``model_version`` starts with ``model_``, which pydantic v2 releases
    # that reserve that prefix report as a protected-namespace conflict;
    # clearing the protected namespaces keeps the field name and silences
    # the warning.
    model_config = ConfigDict(protected_namespaces=())

    predicted_delivery_time: float  # value rounded to two decimals by the caller
    input_features: Dict[str, Any]  # the request fields echoed back
    model_version: str
class HealthResponse(BaseModel):
    """Response body for the health check."""

    # ``model_loaded`` starts with ``model_``, which pydantic v2 releases that
    # reserve that prefix report as a protected-namespace conflict; clearing
    # the protected namespaces keeps the field name and silences the warning.
    model_config = ConfigDict(protected_namespaces=())

    status: str  # "healthy" or "unhealthy"
    message: str
    model_loaded: bool  # True when both pipeline and model are available
# NOTE(review): no route decorator is visible on this function in this
# listing - confirm it is registered on `app`.
def health_check():
    """Report whether both the pipeline and the model are currently loaded.

    Returns:
        HealthResponse: ``status`` is "healthy" only when both artifacts are
        present and not ``None``; otherwise "unhealthy" with an explanatory
        message.
    """
    required_keys = ('pipeline', 'model')
    model_loaded = all(
        model_artifacts.get(key) is not None for key in required_keys
    )

    if model_loaded:
        state, note = "healthy", "ETA Prediction API is running"
    else:
        state, note = "unhealthy", "Model not loaded"

    return HealthResponse(status=state, message=note, model_loaded=model_loaded)
# NOTE(review): no route decorator is visible on this function in this
# listing - confirm it is registered on `app`.
def predict_eta(input_data: DeliveryInput):
    """Predict the delivery time for a single request.

    Args:
        input_data: Validated delivery features.

    Returns:
        PredictionResponse: The prediction rounded to two decimals, the
        echoed input features and the model version string.

    Raises:
        HTTPException: 503 when the pipeline or model is not loaded, 400 when
            transformation/prediction raises ``ValueError``, 500 for any
            other failure.
    """
    # Availability is checked before the try block: raised inside it, the 503
    # would be caught by the generic handler below and turned into a 500.
    pipeline = model_artifacts.get('pipeline')
    model = model_artifacts.get('model')
    if pipeline is None or model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Pipeline or model not loaded"
        )

    try:
        input_dict = input_data.model_dump()
        logger.info(f"Prediction request: {input_dict}")

        # Create DataFrame and transform
        input_df = pd.DataFrame([input_dict])
        transformed_input = pipeline.transform(input_df)

        # Predict
        prediction = model.predict(transformed_input)
        predicted_time = float(prediction[0])
        logger.info(f"Prediction: {predicted_time:.2f} minutes")

        return PredictionResponse(
            predicted_delivery_time=round(predicted_time, 2),
            input_features=input_dict,
            model_version="1.0.0"
        )
    except ValueError as ve:
        logger.error(f"Validation error: {str(ve)}")
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this function in this
# listing - confirm it is registered on `app`.
def predict_batch(input_data: list[DeliveryInput]):
    """Predict delivery times for a list of requests.

    All rows are transformed and predicted in one call instead of one call
    per row.

    Args:
        input_data: Validated delivery feature sets; may be empty.

    Returns:
        dict: ``{"predictions": [...], "count": n}`` where each prediction
        holds the item's ``index``, the ``predicted_delivery_time`` rounded
        to two decimals and the echoed ``input_features``.

    Raises:
        HTTPException: 503 when the pipeline or model is not loaded, 400 when
            transformation/prediction raises ``ValueError``, 500 for any
            other failure.
    """
    # Checked before the try block so the 503 is not re-reported as a 500 by
    # the generic handler below.
    pipeline = model_artifacts.get('pipeline')
    model = model_artifacts.get('model')
    if pipeline is None or model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        rows = [item.model_dump() for item in input_data]
        if not rows:
            # An empty frame has no columns for the pipeline to work on.
            return {"predictions": [], "count": 0}

        transformed = pipeline.transform(pd.DataFrame(rows))
        raw_predictions = model.predict(transformed)

        predictions = [
            {
                "index": idx,
                "predicted_delivery_time": round(float(value), 2),
                "input_features": row
            }
            for idx, (row, value) in enumerate(zip(rows, raw_predictions))
        ]
        return {"predictions": predictions, "count": len(predictions)}
    except ValueError as ve:
        logger.error(f"Validation error: {str(ve)}")
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this function in this
# listing - confirm it is registered on `app`.
def get_model_info():
    """Describe the loaded artifacts.

    Returns:
        dict: The model version string plus the class names of the loaded
        model and pipeline objects.

    Raises:
        HTTPException: 503 when either artifact is missing.
    """
    artifacts_ready = 'pipeline' in model_artifacts and 'model' in model_artifacts
    if not artifacts_ready:
        raise HTTPException(status_code=503, detail="Model not loaded")

    model_class = type(model_artifacts['model'])
    pipeline_class = type(model_artifacts['pipeline'])
    info = {"model_version": "1.0.0"}
    info["model_type"] = model_class.__name__
    info["pipeline_type"] = pipeline_class.__name__
    return info
if __name__ == "__main__":
    import uvicorn

    # Host and port can be overridden through the environment so the same
    # entry point works on platforms that assign the port; the defaults are
    # the previously hard-coded values.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )