|
|
import asyncio
import logging
import math
import os
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
|
|
|
|
|
|
|
|
# Root logger at INFO (basicConfig is a no-op if the root logger already has
# handlers); this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
# The ASGI application. Swagger UI is mounted at the site root ("/") and ReDoc
# at "/redoc"; title, version and description populate the OpenAPI schema.
# NOTE(review): `root` below is also registered on GET "/", so the two routes
# share a path — confirm which one is meant to answer there.
app = FastAPI(
    title="Crop Yield Prediction API",
    version="2.0.0",
    description="Machine Learning API for predicting crop yields based on agricultural parameters",
    redoc_url="/redoc",
    docs_url="/",
)
|
|
|
|
|
|
|
|
# Cross-origin (CORS) policy. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS environment variable; the default "*" (any origin) keeps
# the previous behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec forbids pairing a wildcard origin with credentials;
    # Starlette works around it by echoing the caller's Origin, which in effect
    # lets any site make credentialed requests. Only allow credentials when the
    # origins form an explicit allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Fitted prediction pipeline. Set by the startup hook, reset to None on shutdown
# or when loading fails; the prediction endpoints answer 503 while it is None.
model: Optional[object] = None
|
|
|
|
|
|
|
|
class PredictionRequest(BaseModel):
    """Input features for a single crop-yield prediction.

    The endpoints map these snake_case fields onto the pipeline's column
    names (``Crop``, ``Season``, ``State``, ``Area``, ...).
    """

    # Per-field examples use ``examples=[...]``: the bare ``example=`` keyword is
    # an unknown extra kwarg that pydantic v2 reports as deprecated.
    crop: str = Field(..., description="Type of crop", examples=["Rice"])
    season: str = Field(..., description="Growing season", examples=["Kharif"])
    state: str = Field(..., description="State/Region", examples=["Punjab"])
    area: float = Field(
        ..., gt=0, description="Cultivated area in hectares", examples=[100.0]
    )
    annual_rainfall: float = Field(
        ..., ge=0, description="Annual rainfall in mm", examples=[1200.0]
    )
    fertilizer: float = Field(
        ..., ge=0, description="Fertilizer used in kg", examples=[150.0]
    )
    pesticide: float = Field(
        ..., ge=0, description="Pesticide used in kg", examples=[50.0]
    )
    year: int = Field(
        ..., ge=1900, le=2100, description="Year of cultivation", examples=[2024]
    )

    class Config:
        # Whole-request example rendered in the OpenAPI docs.
        json_schema_extra = {
            "example": {
                "crop": "Rice",
                "season": "Kharif",
                "state": "Punjab",
                "area": 100.0,
                "annual_rainfall": 1200.0,
                "fertilizer": 150.0,
                "pesticide": 50.0,
                "year": 2024
            }
        }
|
|
|
|
|
class PredictionResponse(BaseModel):
    # Response body for POST /predict.
    success: bool
    predicted_yield: float = Field(..., description="Predicted crop yield in tonnes/hectare")
    # Constant unit label; /predict does not set it explicitly.
    unit: str = Field(default="tonnes/hectare")
    # Bucket label chosen by the endpoint from the predicted value.
    yield_category: str
    # Echo of the validated request payload.
    input_data: dict
    message: Optional[str] = Field(default=None)
|
|
|
|
|
class HealthResponse(BaseModel):
    # Response body for GET /health.
    # "healthy" when the model is loaded, otherwise "model_not_loaded".
    status: str = Field(...)
    # NOTE(review): the "model_" prefix may trigger pydantic 2.x's
    # protected-namespace warning; renaming would change the response schema.
    model_loaded: bool = Field(...)
    api_version: str = Field(...)
|
|
|
|
|
class ErrorResponse(BaseModel):
    # Error envelope shape: {"success": false, "error": ..., "detail": ...}.
    # NOTE(review): not referenced by any route or handler visible here; the
    # exception handlers build equivalent dicts directly.
    success: bool = Field(default=False)
    error: str = Field(...)
    detail: Optional[str] = Field(default=None)
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the machine learning model on startup.

    Reads ``crop_yield_pipeline.pkl`` from the directory containing this file
    into the module-level ``model``. Failure is deliberately non-fatal: it is
    logged, ``model`` stays ``None``, the prediction endpoints answer 503 and
    ``/health`` reports ``model_not_loaded``.
    """
    global model

    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(script_dir, "crop_yield_pipeline.pkl")

        logger.info("Looking for model at: %s", model_path)
        logger.info("Current directory: %s", os.getcwd())
        logger.info("Files in script directory: %s", os.listdir(script_dir))

        if not os.path.exists(model_path):
            logger.error("Model file not found at %s", model_path)
            logger.error("Available files: %s", os.listdir(script_dir))
            return

        # joblib.load unpickles the file, which can execute arbitrary code:
        # only ever load a trusted artifact shipped alongside this service.
        model = joblib.load(model_path)
        logger.info("Model loaded successfully!")

    except Exception as e:
        # Best-effort startup: keep serving, but record the full traceback.
        logger.exception("Error loading model: %s", e)
        model = None
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown: drop the module-level model reference."""
    global model
    model = None
    # Plain-text message: the original emoji prefix had been mis-encoded.
    logger.info("Application shutdown complete")
|
|
|
|
|
|
|
|
@app.get("/", include_in_schema=False)
async def root():
    """Return a small JSON index of the API's main endpoints.

    NOTE(review): the app is constructed with ``docs_url="/"``, so the Swagger
    UI route is probably registered on this same path before this handler and
    answers GET "/" first — confirm, and move one of them if the index is wanted.
    """
    return {
        "message": "Crop Yield Prediction API",
        # Report the configured docs path; a hard-coded "/docs" does not exist
        # when docs_url is "/".
        "documentation": app.docs_url,
        "health_check": "/health",
        "prediction_endpoint": "/predict",
    }
|
|
|
|
|
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Check API and model health status.

    ``api_version`` is taken from the FastAPI app, so it cannot drift from the
    version advertised in the OpenAPI schema.
    """
    loaded = model is not None
    return HealthResponse(
        status="healthy" if loaded else "model_not_loaded",
        model_loaded=loaded,
        api_version=app.version,
    )
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse, tags=["Prediction"])
async def predict_yield(request: PredictionRequest):
    """
    Predict crop yield based on input parameters

    Returns predicted yield in tonnes/hectare along with yield category

    Errors: 503 if the model is not loaded; 500 if the model raises or
    returns a non-finite value.
    """
    # Read the global once so the None check and predict() use the same object.
    current_model = model
    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check server logs or contact administrator."
        )

    try:
        # Column names presumably must match those the pipeline was fitted on.
        input_df = pd.DataFrame({
            'Crop': [request.crop],
            'Season': [request.season],
            'State': [request.state],
            'Area': [request.area],
            'Annual_Rainfall': [request.annual_rainfall],
            'Fertilizer': [request.fertilizer],
            'Pesticide': [request.pesticide],
            'Year': [request.year],
        })

        logger.info("Making prediction for: %s in %s", request.crop, request.state)

        # predict() is synchronous CPU-bound work; run it in a worker thread so
        # it does not stall the event loop for concurrent requests.
        raw = await asyncio.to_thread(current_model.predict, input_df)
        prediction = float(raw[0])

        # NaN compares False against every threshold and would be labelled
        # "Exceptional Yield"; it also cannot be encoded as standard JSON.
        if not math.isfinite(prediction):
            raise ValueError(f"Model returned a non-finite prediction: {prediction}")

        if prediction < 1:
            category = "Low Yield"
        elif prediction < 5:
            category = "Moderate Yield"
        elif prediction < 50:
            category = "High Yield"
        else:
            category = "Exceptional Yield"

        response = PredictionResponse(
            success=True,
            predicted_yield=round(prediction, 2),
            yield_category=category,
            input_data=request.dict(),
            message="Prediction successful",
        )

        logger.info("Prediction: %.2f tonnes/hectare (%s)", prediction, category)
        return response

    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during prediction: {str(e)}"
        ) from e
|
|
|
|
|
@app.post("/batch_predict", tags=["Prediction"])
async def batch_predict(requests: list[PredictionRequest]):
    """
    Make predictions for multiple inputs at once

    Accepts a list of prediction requests and returns predictions for all

    Errors: 503 if the model is not loaded; 400 for more than 100 items;
    500 if the model raises or returns a non-finite value.
    """
    # Read the global once so the None check and predict() use the same object.
    current_model = model
    if current_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded"
        )

    # Cap the batch to bound the work a single request can trigger.
    if len(requests) > 100:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Maximum 100 predictions per batch request"
        )

    # An empty batch is valid and yields an empty result; don't hand a 0-row
    # frame to the model (scikit-learn style estimators typically reject it).
    if not requests:
        return {"success": True, "count": 0, "predictions": []}

    try:
        # One frame for the whole batch -> a single vectorised predict() call
        # instead of building a frame and calling predict() per item.
        input_df = pd.DataFrame({
            'Crop': [req.crop for req in requests],
            'Season': [req.season for req in requests],
            'State': [req.state for req in requests],
            'Area': [req.area for req in requests],
            'Annual_Rainfall': [req.annual_rainfall for req in requests],
            'Fertilizer': [req.fertilizer for req in requests],
            'Pesticide': [req.pesticide for req in requests],
            'Year': [req.year for req in requests],
        })

        # predict() is synchronous CPU-bound work; keep it off the event loop.
        predictions = await asyncio.to_thread(current_model.predict, input_df)
        if len(predictions) != len(requests):
            raise ValueError(
                f"Model returned {len(predictions)} predictions for {len(requests)} inputs"
            )

        results = []
        for req, raw in zip(requests, predictions):
            prediction = float(raw)
            # NaN would be labelled "Exceptional Yield" and is not valid JSON.
            if not math.isfinite(prediction):
                raise ValueError(f"Model returned a non-finite prediction: {prediction}")

            if prediction < 1:
                category = "Low Yield"
            elif prediction < 5:
                category = "Moderate Yield"
            elif prediction < 50:
                category = "High Yield"
            else:
                category = "Exceptional Yield"

            results.append({
                "predicted_yield": round(prediction, 2),
                "yield_category": category,
                "input_data": req.dict(),
            })

        return {
            "success": True,
            "count": len(results),
            "predictions": results,
        }

    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during batch prediction: {str(e)}"
        ) from e
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render HTTPExceptions as ``{"success": False, "error": <detail>}``.

    Headers attached to the exception are forwarded, as FastAPI's default
    handler does; previously they were silently dropped.

    NOTE(review): this is registered for FastAPI's HTTPException; errors the
    framework raises itself (e.g. unknown-route 404s) likely use Starlette's
    base class and keep the default ``{"detail": ...}`` body — confirm.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all for unhandled errors: log with traceback, return a 500 envelope.

    NOTE(review): ``detail`` echoes ``str(exc)`` to the client, which can expose
    internal information; consider omitting it in production.
    """
    # exc_info=exc records the traceback, not just the message.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"success": False, "error": "Internal server error", "detail": str(exc)},
    )
|
|
|
|
|
if __name__ == "__main__":
    # Only needed when this file is executed directly, hence the local import.
    import uvicorn

    # Serve on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)