# Author: MHamzaShahid — commit 419e921 ("Update app.py", verified)
import logging
import os
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field, validator
# Module-wide logging: configure the root logger at INFO, then grab a
# logger named after this module for the rest of the file to use.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# FastAPI application object. Swagger UI is mounted at the site root ("/")
# and ReDoc at "/redoc".
app = FastAPI(
    title="Crop Yield Prediction API",
    version="2.0.0",
    description="Machine Learning API for predicting crop yields based on agricultural parameters",
    redoc_url="/redoc",
    docs_url="/",
)
# Add CORS middleware for cross-origin requests: any origin may call the API.
# Credentials are deliberately NOT enabled. The CORS spec (and the
# FastAPI/Starlette docs) forbid pairing a wildcard origin with credentials,
# and nothing in this service reads cookies or auth headers. Enabling them
# would only make the middleware echo arbitrary callers' Origin on
# credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Fitted prediction pipeline shared by the endpoints. It stays None until the
# startup hook loads it (or if loading fails), and is reset to None on shutdown.
model = None
# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    """Input features for a single crop-yield prediction.

    Each field corresponds to one column of the DataFrame handed to the model
    (``crop`` -> ``Crop``, ``annual_rainfall`` -> ``Annual_Rainfall``, ...).
    Numeric bounds are enforced by pydantic before the endpoint runs.
    """

    crop: str = Field(..., description="Type of crop", examples=["Rice"])
    season: str = Field(..., description="Growing season", examples=["Kharif"])
    state: str = Field(..., description="State/Region", examples=["Punjab"])
    area: float = Field(
        ..., gt=0, description="Cultivated area in hectares", examples=[100.0]
    )
    annual_rainfall: float = Field(
        ..., ge=0, description="Annual rainfall in mm", examples=[1200.0]
    )
    fertilizer: float = Field(
        ..., ge=0, description="Fertilizer used in kg", examples=[150.0]
    )
    pesticide: float = Field(
        ..., ge=0, description="Pesticide used in kg", examples=[50.0]
    )
    year: int = Field(
        ..., ge=1900, le=2100, description="Year of cultivation", examples=[2024]
    )

    # pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    # The whole-object example is merged into the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "crop": "Rice",
                "season": "Kharif",
                "state": "Punjab",
                "area": 100.0,
                "annual_rainfall": 1200.0,
                "fertilizer": 150.0,
                "pesticide": 50.0,
                "year": 2024,
            }
        }
    )
class PredictionResponse(BaseModel):
    """Body returned by ``POST /predict``: the rounded prediction, its
    coarse category, and an echo of the request that produced it."""

    success: bool
    predicted_yield: float = Field(
        ..., description="Predicted crop yield in tonnes/hectare"
    )
    unit: str = Field(default="tonnes/hectare")
    yield_category: str
    input_data: dict
    message: Optional[str] = Field(default=None)
class HealthResponse(BaseModel):
    """Body returned by ``GET /health``."""

    # ``model_loaded`` starts with pydantic v2's protected ``model_`` prefix,
    # which triggers a namespace-conflict warning at import time; opting out of
    # the protected namespaces keeps the public field name unchanged.
    model_config = ConfigDict(protected_namespaces=())

    status: str  # "healthy" or "model_not_loaded", as set by the health endpoint
    model_loaded: bool
    api_version: str
class ErrorResponse(BaseModel):
    """Shape of an error payload: ``{success, error, detail}``.

    NOTE(review): no endpoint in this file uses this as a ``response_model``;
    the exception handlers build equivalent dicts by hand.
    """

    error: str
    success: bool = Field(default=False)
    detail: Optional[str] = Field(default=None)
# Startup event to load model
# NOTE(review): ``on_event`` is deprecated in recent FastAPI in favour of a
# lifespan handler; kept here to avoid changing app construction.
@app.on_event("startup")
async def load_model():
    """Load the serialized prediction pipeline into the module-level ``model``.

    Looks for ``crop_yield_pipeline.pkl`` in the directory containing this
    script. If the file is missing, or loading raises, the problem is logged
    and ``model`` is left as ``None`` so the app still starts; the prediction
    endpoints check for that and answer 503.
    """
    global model
    try:
        # Resolve relative to the script, not the working directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(script_dir, "crop_yield_pipeline.pkl")
        available = os.listdir(script_dir)
        logger.info("Looking for model at: %s", model_path)
        logger.info("Current directory: %s", os.getcwd())
        logger.info("Files in script directory: %s", available)

        if not os.path.exists(model_path):
            logger.error("Model file not found at %s", model_path)
            logger.error("Available files: %s", available)
            return

        # joblib.load unpickles the file, which can execute arbitrary code:
        # only ever point this at a trusted artifact shipped with the app.
        model = joblib.load(model_path)
        logger.info("✅ Model loaded successfully!")
    except Exception as e:
        # Best-effort at startup: record the full traceback, keep serving.
        logger.exception("❌ Error loading model: %s", e)
        model = None
@app.on_event("shutdown")
async def shutdown_event():
    """Drop the reference to the loaded model when the application shuts down."""
    global model
    model = None
    logger.info("🔴 Application shutdown complete")
# API Endpoints
@app.get("/", include_in_schema=False)
async def root():
    """Return a small JSON index of the API's main entry points.

    The documentation link is read from ``app.docs_url`` so it always matches
    where Swagger UI is actually mounted (currently "/", not "/docs").

    NOTE(review): because the app is built with ``docs_url="/"``, the Swagger
    UI route is presumably registered before this one and answers ``GET /``
    first, leaving this handler shadowed — confirm whether it is still needed.
    """
    return {
        "message": "Crop Yield Prediction API",
        "documentation": app.docs_url,
        "health_check": "/health",
        "prediction_endpoint": "/predict",
    }
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Check API and model health status.

    ``status`` is "healthy" when the model is loaded and "model_not_loaded"
    otherwise. ``api_version`` is taken from the app's configured version, so
    the two cannot drift apart.
    """
    loaded = model is not None  # read once so both fields agree
    return HealthResponse(
        status="healthy" if loaded else "model_not_loaded",
        model_loaded=loaded,
        api_version=app.version,
    )
# Upper bound on the number of rows accepted by /batch_predict per request.
_MAX_BATCH_SIZE = 100


def _build_input_frame(items: list[PredictionRequest]) -> pd.DataFrame:
    """Build the model input: one row per request, using the exact column
    names (and column order) the model expects."""
    return pd.DataFrame(
        {
            "Crop": [item.crop for item in items],
            "Season": [item.season for item in items],
            "State": [item.state for item in items],
            "Area": [item.area for item in items],
            "Annual_Rainfall": [item.annual_rainfall for item in items],
            "Fertilizer": [item.fertilizer for item in items],
            "Pesticide": [item.pesticide for item in items],
            "Year": [item.year for item in items],
        }
    )


def _categorize_yield(predicted: float) -> str:
    """Map a predicted yield onto a coarse label.

    < 1 -> "Low Yield", < 5 -> "Moderate Yield", < 50 -> "High Yield",
    anything else (including NaN, which fails every comparison) ->
    "Exceptional Yield".
    """
    if predicted < 1:
        return "Low Yield"
    if predicted < 5:
        return "Moderate Yield"
    if predicted < 50:
        return "High Yield"
    return "Exceptional Yield"


@app.post("/predict", response_model=PredictionResponse, tags=["Prediction"])
def predict_yield(request: PredictionRequest):
    """
    Predict crop yield based on input parameters

    Returns predicted yield in tonnes/hectare along with yield category
    """
    # Plain ``def`` (not ``async def``): FastAPI then runs this CPU-bound
    # pandas/model work in its threadpool instead of blocking the event loop.
    pipeline = model  # snapshot: the shutdown hook may reset the global
    if pipeline is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check server logs or contact administrator.",
        )

    try:
        input_df = _build_input_frame([request])
        logger.info("Making prediction for: %s in %s", request.crop, request.state)
        prediction = float(pipeline.predict(input_df)[0])
        category = _categorize_yield(prediction)
        response = PredictionResponse(
            success=True,
            predicted_yield=round(prediction, 2),
            yield_category=category,
            input_data=request.dict(),
            message="Prediction successful",
        )
    except Exception as e:
        logger.exception("❌ Prediction error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during prediction: {str(e)}",
        ) from e

    logger.info("✅ Prediction: %.2f tonnes/hectare (%s)", prediction, category)
    return response


@app.post("/batch_predict", tags=["Prediction"])
def batch_predict(requests: list[PredictionRequest]):
    """
    Make predictions for multiple inputs at once

    Accepts a list of prediction requests and returns predictions for all
    """
    # Plain ``def`` for the same threadpool reason as predict_yield.
    pipeline = model
    if pipeline is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded",
        )
    if len(requests) > _MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Maximum {_MAX_BATCH_SIZE} predictions per batch request",
        )
    # An empty batch is valid: answer with no predictions and skip the model.
    if not requests:
        return {"success": True, "count": 0, "predictions": []}

    try:
        # One vectorised model call for the whole batch rather than one per row.
        predictions = pipeline.predict(_build_input_frame(requests))
        results = [
            {
                "predicted_yield": round(float(raw), 2),
                "yield_category": _categorize_yield(float(raw)),
                "input_data": req.dict(),
            }
            for req, raw in zip(requests, predictions)
        ]
    except Exception as e:
        logger.exception("❌ Batch prediction error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error during batch prediction: {str(e)}",
        ) from e

    return {
        "success": True,
        "count": len(results),
        "predictions": results,
    }
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an ``HTTPException`` as ``{"success": False, "error": <detail>}``.

    Headers attached to the exception (``exc.headers``) are forwarded on the
    response rather than discarded.

    NOTE(review): this is registered for FastAPI's ``HTTPException``; 404/405
    responses raised by the router itself are presumably Starlette
    ``HTTPException``s and bypass this handler, so they keep the default
    ``{"detail": ...}`` shape — confirm if a uniform error format is required.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: log the exception and answer with a 500 JSON body.

    NOTE(review): the body includes ``str(exc)`` as ``detail``, which exposes
    internal error text to clients.
    """
    message = str(exc)
    logger.error(f"Unhandled exception: {message}")
    payload = {"success": False, "error": "Internal server error", "detail": message}
    return JSONResponse(content=payload, status_code=500)
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")