| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| import tensorflow as tf |
| import numpy as np |
| from PIL import Image |
| import io |
| import uvicorn |
| import tempfile |
| import cv2 |
|
|
| |
# Application instance that every route below is registered on.
app = FastAPI(title="Plant Disease Detection API", version="1.0.0")

# CORS policy: any origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; confirm whether credentialed
# cross-origin requests are really needed before tightening either setting.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
| |
# Loaded once at import time so requests reuse the same in-memory model.
# The path is relative, so it is resolved against the process working directory.
model = tf.keras.models.load_model("trained_modela.keras")
|
|
| |
# Labels returned by the API. /predict indexes this list with the argmax of the
# model output, so position i must describe output unit i.
# NOTE(review): presumably this is the label order used at training time;
# verify against the training pipeline before reordering or editing entries.
class_name = [
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Blueberry___healthy",
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]
|
|
|
|
@app.get("/")
async def root():
    """Return a small banner with the API's name and version."""
    banner = {"message": "Plant Disease Detection API", "version": "1.0.0"}
    return banner
|
|
def _image_to_model_input(contents: bytes) -> np.ndarray:
    """Turn raw uploaded bytes into a batch of one 128x128 image array.

    The bytes are written to a file inside a throwaway directory because both
    ``cv2.imread`` and Keras' ``load_img`` are called with a filesystem path.
    The directory, and the file in it, is removed when the ``with`` block
    exits, whether or not decoding succeeds.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        An array holding a single image resized to 128x128.

    Raises:
        HTTPException: 400 if OpenCV cannot decode the bytes as an image.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # delete=False so the file can be reopened by path after it is closed;
        # cleanup is delegated to the enclosing TemporaryDirectory.
        with tempfile.NamedTemporaryFile(
            dir=tmp_dir, suffix=".jpg", delete=False
        ) as tmp:
            tmp.write(contents)
            temp_path = tmp.name

        # cv2.imread returns None instead of raising when decoding fails.
        if cv2.imread(temp_path) is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        image = tf.keras.preprocessing.image.load_img(
            temp_path, target_size=(128, 128)
        )
        # Convert while the file still exists, in case the image is read lazily.
        input_arr = tf.keras.preprocessing.image.img_to_array(image)

    # The model is called with a batch, so wrap the single image in one.
    return np.array([input_arr])


@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """
    Predict plant disease from uploaded image

    Args:
        file: Uploaded image; its declared content type must start with
            ``image/``.

    Returns:
        A dict with ``success`` (always True), ``disease`` (the entry of
        ``class_name`` with the highest score) and ``confidence`` (that score
        as a plain Python float).

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            cannot be decoded; 500 for any other failure during prediction.
    """
    # content_type may be missing on the upload; treat that as "not an image"
    # rather than failing with an AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        input_arr = _image_to_model_input(contents)

        # NOTE: model.predict is a synchronous call made from an async handler.
        prediction = model.predict(input_arr)
        scores = prediction[0]
        result_index = int(np.argmax(scores))

        return {
            "success": True,
            "disease": class_name[result_index],
            # Cast the NumPy scalar to a built-in float so the response can be
            # JSON-encoded.
            "confidence": float(scores[result_index]),
        }
    except HTTPException:
        # Deliberate 4xx responses must reach the client unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Prediction error: {str(e)}"
        ) from e
|
|
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status; performs no checks of its own."""
    status_payload = {"status": "healthy"}
    return status_payload
|
|
@app.get("/classes")
async def get_classes():
    """Return every label the API can report, under the "classes" key."""
    return dict(classes=class_name)
|
|
if __name__ == "__main__":
    # Running this file directly starts the server on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")