import os
import logging
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator  # noqa: F401  (validator kept for compatibility)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

API_VERSION = "2.0.0"

# Initialize FastAPI app.
# BUG FIX: docs_url was "/", which was registered at construction time and
# shadowed the custom root route defined below (that route was unreachable,
# and it advertises "/docs" anyway). Swagger UI now lives at /docs.
app = FastAPI(
    title="Crop Yield Prediction API",
    description="Machine Learning API for predicting crop yields based on agricultural parameters",
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware for cross-origin requests.
# NOTE(review): browsers ignore allow_credentials=True when allow_origins is
# ["*"]; if credentialed requests are actually needed, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for API access
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handle; populated by the startup hook, None until then.
model = None


# ---------------------------------------------------------------------------
# Pydantic models for request/response validation
# ---------------------------------------------------------------------------

class PredictionRequest(BaseModel):
    """Input parameters for a single crop-yield prediction."""

    crop: str = Field(..., description="Type of crop", example="Rice")
    season: str = Field(..., description="Growing season", example="Kharif")
    state: str = Field(..., description="State/Region", example="Punjab")
    area: float = Field(..., gt=0, description="Cultivated area in hectares", example=100.0)
    annual_rainfall: float = Field(..., ge=0, description="Annual rainfall in mm", example=1200.0)
    fertilizer: float = Field(..., ge=0, description="Fertilizer used in kg", example=150.0)
    pesticide: float = Field(..., ge=0, description="Pesticide used in kg", example=50.0)
    year: int = Field(..., ge=1900, le=2100, description="Year of cultivation", example=2024)

    class Config:
        json_schema_extra = {
            "example": {
                "crop": "Rice",
                "season": "Kharif",
                "state": "Punjab",
                "area": 100.0,
                "annual_rainfall": 1200.0,
                "fertilizer": 150.0,
                "pesticide": 50.0,
                "year": 2024,
            }
        }


class PredictionResponse(BaseModel):
    """Result payload for a successful single prediction."""

    success: bool
    predicted_yield: float = Field(..., description="Predicted crop yield in tonnes/hectare")
    unit: str = "tonnes/hectare"
    yield_category: str
    input_data: dict
    message: Optional[str] = None


class HealthResponse(BaseModel):
    """Payload for the /health endpoint."""

    status: str
    model_loaded: bool
    api_version: str


class ErrorResponse(BaseModel):
    """Uniform error envelope (documented shape; emitted by the handlers below)."""

    success: bool = False
    error: str
    detail: Optional[str] = None


# ---------------------------------------------------------------------------
# Internal helpers (shared by the single and batch prediction endpoints)
# ---------------------------------------------------------------------------

def _to_model_frame(req: PredictionRequest) -> pd.DataFrame:
    """Build a single-row DataFrame with the exact column names the pipeline expects."""
    return pd.DataFrame({
        'Crop': [req.crop],
        'Season': [req.season],
        'State': [req.state],
        'Area': [req.area],
        'Annual_Rainfall': [req.annual_rainfall],
        'Fertilizer': [req.fertilizer],
        'Pesticide': [req.pesticide],
        'Year': [req.year],
    })


def _categorize_yield(prediction: float) -> str:
    """Map a raw yield value (tonnes/hectare) onto a human-readable band."""
    if prediction < 1:
        return "Low Yield"
    if prediction < 5:
        return "Moderate Yield"
    if prediction < 50:
        return "High Yield"
    return "Exceptional Yield"


# ---------------------------------------------------------------------------
# Lifecycle hooks
# ---------------------------------------------------------------------------

@app.on_event("startup")
async def load_model():
    """Load the machine learning model on startup.

    Looks for ``crop_yield_pipeline.pkl`` next to this file. On any failure
    the app still starts with ``model = None``; /health reports the state and
    prediction endpoints return 503.
    """
    global model
    try:
        # Resolve the model path relative to this file, not the CWD.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(script_dir, "crop_yield_pipeline.pkl")

        logger.info(f"Looking for model at: {model_path}")
        logger.info(f"Current directory: {os.getcwd()}")
        logger.info(f"Files in script directory: {os.listdir(script_dir)}")

        if not os.path.exists(model_path):
            logger.error(f"Model file not found at {model_path}")
            logger.error(f"Available files: {os.listdir(script_dir)}")
            return

        # Load the model
        model = joblib.load(model_path)
        logger.info("✅ Model loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}")
        model = None


@app.on_event("shutdown")
async def shutdown_event():
    """Release the model reference on shutdown."""
    global model
    model = None
    logger.info("🔴 Application shutdown complete")


# ---------------------------------------------------------------------------
# API Endpoints
# ---------------------------------------------------------------------------

@app.get("/", include_in_schema=False)
async def root():
    """Return a small directory of the API's useful endpoints."""
    return {
        "message": "Crop Yield Prediction API",
        "documentation": "/docs",
        "health_check": "/health",
        "prediction_endpoint": "/predict",
    }


@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Check API and model health status."""
    return HealthResponse(
        status="healthy" if model is not None else "model_not_loaded",
        model_loaded=model is not None,
        api_version=API_VERSION,
    )


@app.post("/predict", response_model=PredictionResponse, tags=["Prediction"])
async def predict_yield(request: PredictionRequest):
    """Predict crop yield based on input parameters.

    Returns predicted yield in tonnes/hectare along with a yield category.
    Raises 503 if the model is not loaded and 500 on prediction failure.
    """
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check server logs or contact administrator.",
        )

    try:
        input_df = _to_model_frame(request)
        logger.info(f"Making prediction for: {request.crop} in {request.state}")

        prediction = model.predict(input_df)[0]
        category = _categorize_yield(prediction)

        response = PredictionResponse(
            success=True,
            predicted_yield=round(float(prediction), 2),
            yield_category=category,
            input_data=request.dict(),
            message="Prediction successful",
        )
        logger.info(f"✅ Prediction: {prediction:.2f} tonnes/hectare ({category})")
        return response
    except Exception as e:
        logger.error(f"❌ Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during prediction: {str(e)}",
        )


@app.post("/batch_predict", tags=["Prediction"])
async def batch_predict(requests: list[PredictionRequest]):
    """Make predictions for multiple inputs at once.

    Accepts a list of prediction requests (max 100) and returns predictions
    for all of them. Raises 503 if the model is not loaded.
    """
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded",
        )
    if len(requests) > 100:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Maximum 100 predictions per batch request",
        )

    try:
        results = []
        for req in requests:
            prediction = model.predict(_to_model_frame(req))[0]
            results.append({
                "predicted_yield": round(float(prediction), 2),
                "yield_category": _categorize_yield(prediction),
                "input_data": req.dict(),
            })
        return {
            "success": True,
            "count": len(results),
            "predictions": results,
        }
    except Exception as e:
        logger.error(f"❌ Batch prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during batch prediction: {str(e)}",
        )


# ---------------------------------------------------------------------------
# Error handlers — keep every error in the ErrorResponse envelope shape
# ---------------------------------------------------------------------------

@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Wrap HTTPExceptions in the uniform {success, error} envelope."""
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail},
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: log and return a generic 500 envelope."""
    logger.error(f"Unhandled exception: {str(exc)}")
    return JSONResponse(
        status_code=500,
        content={"success": False, "error": "Internal server error", "detail": str(exc)},
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)