# plant-detection / app.py
# Committed by AdeshJain
# Last commit: "Update app.py" (dbad62f, verified)
import io
import os
from typing import Dict, Optional

import httpx
import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from groq import Groq
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow import keras
app = FastAPI(title="CNN Image Prediction API")

# CORS: accept browser requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm credentials are actually needed, or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# API credentials are read from the environment (e.g. deployment secrets)
# instead of being hard-coded: keys committed to source are readable by anyone
# with access to the repository.
# NOTE(review): the keys previously hard-coded here are exposed and should be
# revoked/rotated.
# Groq client used by get_llm_remedies. With api_key=None the Groq SDK is
# expected to fall back to GROQ_API_KEY itself and refuse to construct without
# a key — confirm for the installed groq version.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
# OpenWeatherMap API key (openweathermap.org). If unset, the API is expected to
# reject requests, which get_weather_data turns into None (no weather context).
WEATHER_API_KEY = os.environ.get("WEATHER_API_KEY", "")
# Input resolution that preprocess_image resizes every upload to.
IMG_SIZE = (224, 224)
# Loaded Keras model; assigned by the load_model_keras startup hook, None until then.
model = None
# Class labels, indexed by model output position (argmax index -> label name).
# The list must contain exactly one unique entry per model output; a duplicated
# run of labels here previously shifted every later index to the wrong name.
# NOTE(review): order must match the class order used in training — verify.
# Several names use "__" instead of "___" as the crop/condition separator
# (e.g. "Corn_(maize)__Common_rust", "Pepper,bell__Healthy") and the cotton
# entries follow a different scheme; they are kept verbatim because API
# clients may already depend on these strings.
CLASS_LABELS = [
    "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___Healthy",
    "Blueberry___Healthy",
    "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___Healthy",
    "Corn_(maize)___Cercospora_leaf_spot_Gray_leaf_spot", "Corn_(maize)__Common_rust",
    "Corn_(maize)___Northern_Leaf_Blight", "Corn_(maize)___Healthy",
    "Grape___Black_rot", "Grape__Esca(Black_Measles)", "Grape__Leaf_blight(Isariopsis_Leaf_Spot)",
    "Grape___Healthy",
    "Orange__Haunglongbing(Citrus_greening)",
    "Peach___Bacterial_spot", "Peach___Healthy",
    "Pepper,bell__Bacterial_spot", "Pepper,bell__Healthy",
    "Potato___Early_blight", "Potato___Late_blight", "Potato___Healthy",
    "Raspberry___Healthy",
    "Soybean___Healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch", "Strawberry___Healthy",
    "cotton : bacterial_blight", "cotton : curl_virus", "fussarium_wilt",
    "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites_Two-spotted_spider_mite", "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus",
    "Tomato___Healthy",
]
@app.on_event("startup")
async def load_model_keras():
    """Fetch the plant-disease Keras model from the Hugging Face Hub and load it.

    Runs as a FastAPI startup hook and stores the loaded model in the
    module-level ``model``. The file is fetched via ``hf_hub_download`` into
    ``./model_cache``. Download or load errors are printed and re-raised, so
    the startup hook fails rather than leaving ``model`` unset silently.

    After loading, the model's output width is compared with
    ``len(CLASS_LABELS)``. A mismatch is printed as a warning (not raised):
    it means predicted indices would map to the wrong label names.
    """
    global model
    repo_id = "AdeshJain/plant-detection"
    filename = "plant_disease_model.keras"
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="./model_cache",
        )
        model = keras.models.load_model(model_path)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    print(f"Model loaded successfully from {repo_id}!")

    # Best-effort sanity check: one label per output unit of the final layer.
    try:
        num_outputs = model.output_shape[-1]
    except (AttributeError, IndexError, TypeError, ValueError):
        num_outputs = None  # output shape not introspectable for this model type
    if isinstance(num_outputs, int) and num_outputs != len(CLASS_LABELS):
        print(
            f"WARNING: model has {num_outputs} outputs but CLASS_LABELS has "
            f"{len(CLASS_LABELS)} entries; predicted label names will be wrong."
        )
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image batch ready for ``model.predict``.

    The image is converted to RGB when needed, resized to ``IMG_SIZE``, cast to
    float32 and divided by 255, and given a leading batch axis of length 1.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    resized = rgb.resize(IMG_SIZE)
    scaled = np.asarray(resized, dtype=np.float32) / 255.0
    return scaled[np.newaxis, ...]
async def get_weather_data(lat: float, lon: float) -> Optional[Dict]:
    """Fetch current weather for a coordinate from the OpenWeatherMap API.

    Best-effort: a transport error, a non-200 status or an undecodable body is
    printed and reported as ``None`` so callers can continue without weather.

    Args:
        lat: Latitude of the location.
        lon: Longitude of the location.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``.
    """
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Let httpx build and encode the query string rather than interpolating it.
    params = {
        "lat": lat,
        "lon": lon,
        "appid": WEATHER_API_KEY,
        "units": "metric",
    }
    try:
        async with httpx.AsyncClient() as client_http:
            response = await client_http.get(url, params=params)
            if response.status_code != 200:
                # Previously swallowed silently; log so misconfiguration is visible.
                print(f"Weather API returned HTTP {response.status_code}")
                return None
            return response.json()
    except Exception as e:
        print(f"Weather API error: {e}")
        return None
# Groq model id used for remedy generation; also reported back to API clients.
_REMEDIES_LLM_MODEL = "llama-3.3-70b-versatile"
# Completion budget for the four detailed sections the prompt requests. The
# previous limit of 500 tokens is too small for that and truncated answers.
_REMEDIES_MAX_TOKENS = 2048
_REMEDIES_SYSTEM_PROMPT = (
    "You are an expert agricultural consultant with deep knowledge of plant "
    "pathology and sustainable farming practices. Provide detailed, practical "
    "advice for farmers."
)
# User prompt; filled with str.format(disease=..., weather_context=...).
_REMEDIES_PROMPT_TEMPLATE = """You are an expert agricultural consultant specializing in plant disease management. A farmer has detected the following plant disease through image analysis:
**Detected Disease: {disease}**
{weather_context}
Please provide comprehensive treatment recommendations in the following structured format:
## 1. CHEMICAL TREATMENT METHODS
Provide specific chemical treatments including:
- Recommended fungicides/pesticides (with active ingredients)
- Application dosage and concentration
- Application frequency and timing
- Safety precautions and protective equipment needed
- Pre-harvest intervals if applicable
## 2. SUSTAINABLE & NATURAL TREATMENT METHODS
Provide organic and eco-friendly solutions including:
- Natural/organic sprays and remedies
- Biological control methods
- Cultural practices and preventive measures
- Soil management techniques
- Plant-based solutions
## 3. WEATHER-SPECIFIC RECOMMENDATIONS
Based on the current weather conditions:
- How weather affects disease progression
- Best time to apply treatments
- Additional precautions needed
- Environmental considerations
## 4. PREVENTIVE MEASURES
Long-term strategies to prevent recurrence:
- Crop rotation suggestions
- Irrigation management
- Nutrient management
- Monitoring practices
Please be specific, practical, and actionable. Consider both immediate treatment and long-term disease management."""


def _format_weather_context(weather_data: Optional[Dict], location: str) -> str:
    """Render the prompt's weather section, falling back when data is missing or malformed."""
    if weather_data:
        try:
            temp = weather_data["main"]["temp"]
            humidity = weather_data["main"]["humidity"]
            weather_desc = weather_data["weather"][0]["description"]
        except (KeyError, IndexError, TypeError):
            # Unexpected response shape: treat it like missing weather rather than failing.
            pass
        else:
            return (
                f"\nCurrent Weather Conditions at {location}:\n"
                f"- Temperature: {temp}°C\n"
                f"- Humidity: {humidity}%\n"
                f"- Conditions: {weather_desc}\n"
            )
    return f"Location: {location}\n(Weather data unavailable)"


def get_llm_remedies(disease: str, weather_data: Optional[Dict], location: str) -> Dict:
    """Ask the Groq LLM for treatment recommendations for a detected disease.

    This is a blocking network call.

    Args:
        disease: Predicted class label, inserted verbatim into the prompt.
        weather_data: Weather JSON as returned by ``get_weather_data``, or None.
            Only ``main.temp``, ``main.humidity`` and ``weather[0].description``
            are used; if any is missing the prompt says weather is unavailable.
        location: Human-readable location name shown in the prompt.

    Returns:
        ``{"remedies": <reply text>, "model_used": <Groq model id>}``.

    Raises:
        HTTPException: 500 if the Groq request fails or returns no choices.
    """
    prompt = _REMEDIES_PROMPT_TEMPLATE.format(
        disease=disease,
        weather_context=_format_weather_context(weather_data, location),
    )
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": _REMEDIES_SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            model=_REMEDIES_LLM_MODEL,
            temperature=0.7,
            max_tokens=_REMEDIES_MAX_TOKENS,
        )
        choice = chat_completion.choices[0]
        remedies = choice.message.content
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}") from e
    # finish_reason "length" means the reply was cut off at the token budget.
    if getattr(choice, "finish_reason", None) == "length":
        print("Warning: LLM remedies were truncated at the max_tokens limit")
    return {
        "remedies": remedies,
        "model_used": _REMEDIES_LLM_MODEL,
    }
@app.get("/")
async def root():
    """Health check: report that the API is up and whether the model is loaded."""
    model_ready = model is not None
    return dict(
        message="CNN Image Prediction API is running",
        status="healthy",
        model_loaded=model_ready,
    )
# @app.get("/test")
# async def test_prediction():
# """Test endpoint using a hardcoded image from local directory"""
# if model is None:
# raise HTTPException(status_code=500, detail="Model not loaded")
# test_image_path = "test2.jpg"
# if not os.path.exists(test_image_path):
# raise HTTPException(
# status_code=404,
# detail=f"Test image not found at {test_image_path}. Please place a test image in the directory."
# )
# try:
# image = Image.open(test_image_path)
# processed_img = preprocess_image(image)
# prediction = model.predict(processed_img)
# predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
# confidence = float(np.max(prediction))
# predicted_class_name = CLASS_LABELS[predicted_class_idx]
# top_5_indices = np.argsort(prediction[0])[-5:][::-1]
# top_5_predictions = [
# {
# "class": CLASS_LABELS[idx],
# "confidence": float(prediction[0][idx])
# }
# for idx in top_5_indices
# ]
# print("Top-5 indices & confidences:", top_5_indices)
# return {
# "top_prediction": {
# "class": predicted_class_name,
# "confidence": confidence
# }
# }
# except Exception as e:
# raise HTTPException(status_code=500, detail=f"Test prediction error: {str(e)}")
@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> Dict:
    """Classify an uploaded plant image with the loaded CNN model.

    Args:
        file: Multipart upload; must declare an ``image/*`` content type.

    Returns:
        The top prediction (label, index, confidence), the five highest-scoring
        classes with their confidences (highest first), and the upload filename.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded; 500 if the model is not loaded or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # content_type is None when the part was sent without a Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; decode now so corrupt data is a 400.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    try:
        processed_img = preprocess_image(image)
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(processed_img, verbose=0)
        scores = prediction[0]  # single-image batch
        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])
        predicted_class_name = CLASS_LABELS[predicted_class_idx]
        top_5_predictions = [
            {"class": CLASS_LABELS[idx], "confidence": float(scores[idx])}
            for idx in np.argsort(scores)[-5:][::-1]
        ]
        return {
            "success": True,
            "predicted_class": predicted_class_name,
            "predicted_class_index": predicted_class_idx,
            "confidence": confidence,
            "top_5_predictions": top_5_predictions,
            "filename": file.filename,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
@app.post("/predict-with-remedies")
async def predict_with_remedies(
    file: UploadFile = File(...),
    latitude: float = Form(...),
    longitude: float = Form(...),
    city: Optional[str] = Form(None),
) -> Dict:
    """
    Predict plant disease and provide AI-generated remedies with weather context.

    Args:
        file: Uploaded plant image; must declare an ``image/*`` content type.
        latitude: Location latitude for weather data.
        longitude: Location longitude for weather data.
        city: Optional city name for display; when empty, the coordinates are
            used as the location name.

    Returns:
        Dictionary containing the prediction (with the top-3 classes), the
        location, a weather summary (None when weather could not be fetched)
        and treatment recommendations. Predictions whose label contains
        "Healthy" get a fixed rule-based message instead of an LLM call.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded; 500 if the model is not loaded, the LLM call fails, or
            any other processing step fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # content_type is None when the part was sent without a Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; decode now so corrupt data is a 400.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    try:
        # Step 1: disease prediction with the CNN model.
        processed_img = preprocess_image(image)
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(processed_img, verbose=0)
        scores = prediction[0]  # single-image batch
        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])
        predicted_class_name = CLASS_LABELS[predicted_class_idx]
        is_healthy = "Healthy" in predicted_class_name
        # Top 3 predictions (highest first) for additional context.
        top_3_predictions = [
            {"class": CLASS_LABELS[idx], "confidence": float(scores[idx])}
            for idx in np.argsort(scores)[-3:][::-1]
        ]

        # Step 2: weather for the location (best-effort; None on failure).
        location_name = city if city else f"Lat: {latitude}, Lon: {longitude}"
        weather_data = await get_weather_data(latitude, longitude)

        # Step 3: remedies — only consult the LLM when a disease was detected.
        if is_healthy:
            remedies_data = {
                "remedies": (
                    "🎉 Great news! Your plant appears to be healthy. No treatment needed.\n\n"
                    "**Preventive Care Tips:**\n"
                    "- Continue regular monitoring\n"
                    "- Maintain proper watering schedule\n"
                    "- Ensure adequate sunlight\n"
                    "- Keep the area clean and weed-free\n"
                    "- Monitor for any changes in plant appearance"
                ),
                "model_used": "rule-based",
            }
        else:
            # get_llm_remedies makes a blocking Groq request; run it in a worker
            # thread so the event loop keeps serving other requests meanwhile.
            remedies_data = await run_in_threadpool(
                get_llm_remedies, predicted_class_name, weather_data, location_name
            )

        # Weather is optional context: tolerate missing keys instead of failing
        # the whole request when the API response is incomplete.
        weather_info = None
        if weather_data:
            main = weather_data.get("main") or {}
            conditions = weather_data.get("weather") or [{}]
            wind = weather_data.get("wind") or {}
            weather_info = {
                "temperature": main.get("temp"),
                "feels_like": main.get("feels_like"),
                "humidity": main.get("humidity"),
                "pressure": main.get("pressure"),
                "conditions": conditions[0].get("description"),
                "wind_speed": wind.get("speed"),
            }

        return {
            "success": True,
            "prediction": {
                "disease": predicted_class_name,
                "confidence": confidence,
                "is_healthy": is_healthy,
                "top_3_predictions": top_3_predictions,
            },
            "location": {
                "name": location_name,
                "latitude": latitude,
                "longitude": longitude,
            },
            "weather": weather_info,
            "treatment": {
                "remedies": remedies_data["remedies"],
                "llm_model": remedies_data["model_used"],
            },
            "filename": file.filename,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
if __name__ == "__main__":
    # Listen on all interfaces; port 7860 is presumably what the hosting
    # platform (e.g. a Hugging Face Space) expects — confirm before changing.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)