File size: 4,491 Bytes
4f6b316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import pandas as pd
import joblib
import os
import logging
import numpy as np
from typing import List, Dict

# Set up logging: INFO-level root configuration plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Crop Recommendation API")

# Define model file paths. Relative paths resolve against the process's current
# working directory. Each path can be overridden through an environment variable
# so the service can be deployed elsewhere without editing this file; the
# defaults are the original hard-coded file names.
MODEL_PATH = os.environ.get("CROP_MODEL_PATH", "final_model.joblib")
ENCODER_PATH = os.environ.get("CROP_ENCODER_PATH", "label_encoder.joblib")

# Load model and encoder once at import time so every request reuses the same
# in-memory objects. Any failure here aborts startup.
# NOTE(review): joblib.load unpickles its input and can execute arbitrary code;
# only point MODEL_PATH / ENCODER_PATH at trusted, locally produced artifacts.
try:
    final_RF = joblib.load(MODEL_PATH)
    label_encoder = joblib.load(ENCODER_PATH)
    # .tolist() yields native Python values (str/int), which JSON can always
    # encode, unlike numpy scalar types.
    VALID_CROPS = np.asarray(label_encoder.classes_).tolist()
    logger.info("Model and encoder loaded successfully")
except Exception as e:
    logger.exception(f"Failed to load model or encoder: {str(e)}")
    # RuntimeError subclasses Exception, so existing `except Exception` handlers
    # still match; `from e` preserves the original cause in the traceback.
    raise RuntimeError(f"Failed to load model or encoder: {str(e)}") from e


# Request body for the /predict_crop endpoint: soil nutrient levels plus climate
# readings. Every field is required, and out-of-range values are rejected by
# pydantic before the handler runs.
# Field declaration order matters: predict_crop_endpoint converts this model to a
# dict, and predict_crop turns that dict into a one-row DataFrame whose column
# order follows it. NOTE(review): presumably this must match the training
# feature order — confirm before reordering.
class CropInput(BaseModel):
    # Macronutrients, each bounded to [0, 200] kg/ha.
    N: float = Field(
        default=..., description="Nitrogen content in soil (kg/ha)", ge=0, le=200
    )
    P: float = Field(
        default=..., description="Phosphorus content in soil (kg/ha)", ge=0, le=200
    )
    K: float = Field(
        default=..., description="Potassium content in soil (kg/ha)", ge=0, le=200
    )
    # Climate / soil chemistry.
    temperature: float = Field(
        default=..., description="Temperature in Celsius", ge=0, le=50
    )
    ph: float = Field(default=..., description="Soil pH value", ge=0, le=14)
    rainfall: float = Field(
        default=..., description="Rainfall in millimeters", ge=0, le=2000
    )


def get_top_n_classes(
    model, X_df: pd.DataFrame, label_encoder, n: int = 5
) -> List[Dict]:
    """Get the top N predicted classes with their probabilities.

    Args:
        model: Fitted classifier exposing ``predict_proba``. If it also exposes
            ``classes_``, that array is used to map probability columns to
            encoded labels.
        X_df: Feature frame; only its first row's prediction is used.
        label_encoder: Encoder whose ``inverse_transform`` maps encoded labels
            back to crop names.
        n: Maximum number of classes to return. ``n <= 0`` returns ``[]``.

    Returns:
        Up to ``n`` dicts ``{"crop": label, "probability": float}``, ordered by
        descending probability. Ties keep the model's column order.

    Raises:
        ValueError: If prediction or label decoding fails. The original error
            is chained as ``__cause__``.
    """
    if n <= 0:
        # A negative slice bound would otherwise return "all but the last" classes.
        return []
    try:
        probs = np.asarray(model.predict_proba(X_df))[0]
        # Stable sort on negated probabilities: highest first, deterministic ties.
        top_columns = np.argsort(-probs, kind="stable")[:n]
        # predict_proba columns are ordered by model.classes_ (the encoded
        # labels), which need not equal 0..k-1, e.g. when a class was absent
        # from training. Fall back to the column index if the model has no
        # classes_ attribute.
        model_classes = getattr(model, "classes_", None)
        encoded = (
            np.asarray(model_classes)[top_columns]
            if model_classes is not None
            else top_columns
        )
        top_labels = label_encoder.inverse_transform(encoded)
        return [
            {
                # Convert numpy scalars to native Python values for JSON output.
                "crop": label.item() if isinstance(label, np.generic) else label,
                "probability": float(probs[column]),
            }
            for label, column in zip(top_labels, top_columns)
        ]
    except Exception as e:
        logger.error(f"Error in get_top_n_classes: {str(e)}")
        raise ValueError(f"Failed to compute top classes: {str(e)}") from e


# Synchronous prediction function
def predict_crop(input_data: Dict, top_n: int = 5) -> Dict:
    """Run the crop classifier on one observation and return the top candidates.

    Args:
        input_data: Mapping of feature name to value. It must contain every name
            in ``required_cols`` below; any other keys are ignored.
        top_n: How many ranked predictions to return (default 5).

    Returns:
        ``{"predictions": [...], "status": "success"}`` on success. If conversion
        or prediction raises, returns
        ``{"predictions": [], "status": "failure", "error": <message>}``; such
        exceptions are logged and reported through ``status`` rather than
        propagated.
    """
    # Feature columns, in the order the DataFrame is handed to the model.
    # NOTE(review): presumably this matches the column order used at training
    # time — confirm.
    required_cols = ["N", "P", "K", "temperature", "ph", "rainfall"]
    try:
        input_df = pd.DataFrame([input_data])

        # Preserve required_cols order so the error message is deterministic.
        missing_cols = [col for col in required_cols if col not in input_df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Pass exactly the expected features, in a fixed order, so extra keys or
        # a differently ordered mapping cannot change what the model receives.
        input_df = input_df[required_cols]

        top_predictions = get_top_n_classes(
            final_RF, input_df, label_encoder, n=top_n
        )
        return {"predictions": top_predictions, "status": "success"}
    except Exception as e:
        logger.exception(f"Prediction error: {str(e)}")
        return {"predictions": [], "status": "failure", "error": str(e)}


@app.post("/predict_crop")
async def predict_crop_endpoint(input_data: CropInput):
    """Run the crop model on a validated request and return ranked predictions.

    Responses:
        200 -- ``{"predictions": [...], "status": "success"}``.
        400 -- the prediction step reported a failure; the message is in ``detail``.
        500 -- a model artifact file is missing, or an unexpected error occurred.
    """
    # Local import keeps the file's top-level import block unchanged.
    from fastapi.concurrency import run_in_threadpool

    try:
        # NOTE(review): the artifacts are already loaded at import time, so this
        # only detects files removed after startup. It is kept to preserve the
        # existing 500 response in that case.
        for path in [MODEL_PATH, ENCODER_PATH]:
            if not os.path.exists(path):
                raise HTTPException(status_code=500, detail=f"File not found: {path}")

        # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
        to_dict = getattr(input_data, "model_dump", None) or input_data.dict
        input_dict = to_dict()

        # predict_crop runs the model synchronously; execute it in the
        # threadpool so it does not block the event loop.
        result = await run_in_threadpool(predict_crop, input_dict)

        if result["status"] == "failure":
            raise HTTPException(status_code=400, detail=result["error"])

        return result

    except HTTPException:
        # Already carries the intended status code. Without this clause the
        # generic handler below would rewrap the 400 as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error processing prediction: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error processing prediction: {str(e)}"
        ) from e


@app.get("/")
async def root():
    # Landing route: reports that the service is up and points clients at the
    # prediction endpoint. Takes no input and always returns the same payload.
    message = (
        "Crop Recommendation API is running. "
        "Use /predict_crop endpoint to send input data."
    )
    return {"message": message}


@app.get("/valid_inputs")
async def get_valid_inputs():
    # Describes the accepted range and unit for each request field, plus the
    # crop labels the loaded encoder knows about.
    # NOTE(review): these bounds duplicate the Field(ge=..., le=...) constraints
    # on CropInput; keep the two in sync.
    field_specs = (
        ("N", 0, 200, "kg/ha"),
        ("P", 0, 200, "kg/ha"),
        ("K", 0, 200, "kg/ha"),
        ("temperature", 0, 50, "Celsius"),
        ("ph", 0, 14, "pH"),
        ("rainfall", 0, 2000, "mm"),
    )
    payload = {
        field: {"min": low, "max": high, "unit": unit}
        for field, low, high, unit in field_specs
    }
    payload["possible_crops"] = VALID_CROPS
    return payload


# How to run the service locally:
#   conda activate crop_new
#   uvicorn crop_recommender:app --host 0.0.0.0 --port 8004