# crop-prediction / app.py
# Committed by zidea21: "Rename main.py to app.py" (c87bc99, verified)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import pandas as pd
import joblib
import os
import logging
import numpy as np
from typing import List, Dict
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Crop Recommendation API")
# Define model file paths. They are relative, so they resolve against the
# process's current working directory, not this file's location.
MODEL_PATH = r"final_model.joblib"
ENCODER_PATH = r"label_encoder.joblib"
# Load model and encoder at startup; importing this module fails if either
# artifact cannot be loaded.
# NOTE: joblib.load unpickles the file, so only load trusted artifacts.
try:
    final_RF = joblib.load(MODEL_PATH)
    label_encoder = joblib.load(ENCODER_PATH)
    # tolist() turns numpy scalars into native Python values so the list is
    # JSON-serializable by FastAPI whatever the encoder's label dtype is.
    VALID_CROPS = np.asarray(label_encoder.classes_).tolist()
    logger.info("Model and encoder loaded successfully")
except Exception as e:
    # logger.exception records the traceback along with the message.
    logger.exception(f"Failed to load model or encoder: {str(e)}")
    # RuntimeError subclasses Exception, so existing handlers still match;
    # "from e" keeps the original load error as the cause.
    raise RuntimeError(f"Failed to load model or encoder: {str(e)}") from e
# Pydantic model for input validation.
# Request body for POST /predict_crop. Every field is required and pydantic
# enforces the inclusive ge/le bounds before the handler runs.
# Field declaration order matters: the endpoint turns this model into a dict
# that becomes a one-row DataFrame, so columns follow this order.
class CropInput(BaseModel):
    # Soil macronutrients, each bounded to [0, 200]
    N: float = Field(..., description="Nitrogen content in soil (kg/ha)", ge=0, le=200)
    P: float = Field(..., description="Phosphorus content in soil (kg/ha)", ge=0, le=200)
    K: float = Field(..., description="Potassium content in soil (kg/ha)", ge=0, le=200)
    # Climate and soil chemistry
    temperature: float = Field(..., description="Temperature in Celsius", ge=0, le=50)
    ph: float = Field(..., description="Soil pH value", ge=0, le=14)
    rainfall: float = Field(..., description="Rainfall in millimeters", ge=0, le=2000)
def get_top_n_classes(
    model, X_df: pd.DataFrame, label_encoder, n: int = 5
) -> List[Dict]:
    """Get the top N predicted classes with their probabilities.

    Only the first row of ``X_df`` is scored.

    Args:
        model: Fitted classifier exposing ``predict_proba``. When it also has
            ``classes_``, those are used to map probability columns to
            encoded labels.
        X_df: Feature frame; row 0 is the sample to rank.
        label_encoder: Object whose ``inverse_transform`` maps encoded labels
            back to crop names.
        n: Number of classes to return. Values <= 0 yield an empty list.

    Returns:
        List of ``{"crop": label, "probability": float}`` dicts, most probable
        first. Labels are native Python values (JSON-serializable).

    Raises:
        ValueError: If any step of the computation fails (original error
            chained as the cause).
    """
    try:
        probs = model.predict_proba(X_df)[0]
        # A negative n would otherwise slice to "all but the last |n|".
        n = max(n, 0)
        top_indices = np.argsort(probs)[::-1][:n]
        top_probs = probs[top_indices]
        # predict_proba columns follow model.classes_, the encoded labels seen
        # at fit time. These coincide with positions 0..k-1 only if every
        # encoded class was present during training, so map through classes_.
        classes = getattr(model, "classes_", None)
        encoded = top_indices if classes is None else np.asarray(classes)[top_indices]
        top_labels = np.asarray(label_encoder.inverse_transform(encoded)).tolist()
        return [
            {"crop": label, "probability": float(prob)}
            for label, prob in zip(top_labels, top_probs)
        ]
    except Exception as e:
        logger.error(f"Error in get_top_n_classes: {str(e)}")
        raise ValueError(f"Failed to compute top classes: {str(e)}") from e
# Synchronous prediction function
def predict_crop(input_data: Dict, top_n: int = 5) -> Dict:
    """Rank the most suitable crops for one set of soil/climate readings.

    Args:
        input_data: Mapping containing at least the keys N, P, K, temperature,
            ph and rainfall. Any other keys are ignored.
        top_n: Number of ranked crops to return (default 5).

    Returns:
        ``{"predictions": [...], "status": "success"}`` on success, otherwise
        ``{"predictions": [], "status": "failure", "error": <message>}``.
        Errors are reported in the returned dict rather than raised.
    """
    required_cols = ["N", "P", "K", "temperature", "ph", "rainfall"]
    try:
        # Convert input to a one-row DataFrame
        input_df = pd.DataFrame([input_data])
        missing_cols = set(required_cols) - set(input_df.columns)
        if missing_cols:
            # sorted() keeps the message stable across runs (set order is not).
            raise ValueError(f"Missing required columns: {sorted(missing_cols)}")
        # Select exactly the expected features, in a fixed order, so extra keys
        # or a different dict ordering cannot change what the model receives.
        input_df = input_df[required_cols]
        top_predictions = get_top_n_classes(
            final_RF, input_df, label_encoder, n=top_n
        )
        return {"predictions": top_predictions, "status": "success"}
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return {"predictions": [], "status": "failure", "error": str(e)}
# POST /predict_crop: validates the body via CropInput, runs predict_crop and
# returns its result dict.
# Error mapping:
#   * model/encoder file missing on disk     -> 500 "File not found: <path>"
#   * predict_crop reports status "failure"  -> 400 with its error message
#   * any other unexpected exception         -> 500 "Error processing prediction: ..."
@app.post("/predict_crop")
async def predict_crop_endpoint(input_data: CropInput):
    try:
        # The artifacts are loaded at import time. This check only fires if
        # the files have since disappeared from the working directory.
        for path in (MODEL_PATH, ENCODER_PATH):
            if not os.path.exists(path):
                raise HTTPException(status_code=500, detail=f"File not found: {path}")
        # Pydantic v2 provides model_dump() (dict() is deprecated there);
        # Pydantic v1 only has dict().
        dump = getattr(input_data, "model_dump", None) or input_data.dict
        result = predict_crop(dump())
    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap it.
        raise
    except Exception as e:
        logger.error(f"Error processing prediction: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error processing prediction: {str(e)}"
        ) from e
    # Outside the try so the 400 is not swallowed by the generic 500 handler.
    if result["status"] == "failure":
        raise HTTPException(status_code=400, detail=result["error"])
    return result
# GET /: landing route returning a fixed message that points clients at the
# /predict_crop endpoint. It reads nothing from the request or the model.
@app.get("/")
async def root():
    payload = {
        "message": "Crop Recommendation API is running. Use /predict_crop endpoint to send input data."
    }
    return payload
# GET /valid_inputs: reports the accepted range and unit for each input field,
# plus the crop labels the encoder knows (VALID_CROPS).
# The ranges mirror the Field bounds declared on CropInput; keep both in sync.
@app.get("/valid_inputs")
async def get_valid_inputs():
    # N, P and K share the same range and unit.
    spec = {
        nutrient: {"min": 0, "max": 200, "unit": "kg/ha"}
        for nutrient in ("N", "P", "K")
    }
    spec["temperature"] = {"min": 0, "max": 50, "unit": "Celsius"}
    spec["ph"] = {"min": 0, "max": 14, "unit": "pH"}
    spec["rainfall"] = {"min": 0, "max": 2000, "unit": "mm"}
    spec["possible_crops"] = VALID_CROPS
    return spec
# conda activate crop_new
# uvicorn app:app --host 0.0.0.0 --port 8004