Spaces:
Sleeping
Sleeping
import io
import os
from typing import Dict, Optional

import httpx
import numpy as np
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from groq import Groq
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
from tensorflow import keras
app = FastAPI(title="CNN Image Prediction API")

# Cross-origin policy: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests, so any site
# could call this API with the user's cookies. Nothing here appears to rely on
# cookies/auth, so allow_credentials=False is probably safe -- confirm with the
# frontend before changing it.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Third-party credentials come from the environment, never from source code.
# NOTE(review): the Groq and OpenWeather keys that were previously hard-coded
# here are exposed in the source history and should be revoked and rotated.

# Groq client used to generate treatment advice. Set GROQ_API_KEY in the
# environment (e.g. as a Space secret); presumably the Groq SDK refuses to
# construct the client when no key is available -- confirm for your version.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# OpenWeather (openweathermap.org) API key; None when WEATHER_API_KEY is unset.
WEATHER_API_KEY = os.environ.get("WEATHER_API_KEY")
# Keras classifier handle; None until load_model_keras() assigns it.
model = None

# Size passed to PIL's Image.resize, i.e. (width, height), for every input.
IMG_SIZE = (224, 224)

# Label table: entry i names the class scored at position i of the model's
# output vector, so order matters as much as content.
# NOTE(review): this list looks damaged. ' fussarium_wilt', "Tomato___Leaf_Mold"
# and "Tomato___Septoria_leaf_spot" each appear twice (indices 30-35), several
# entries carry leading spaces, a few use "__" instead of "___", and the tomato
# bacterial-spot / early-blight / late-blight classes are absent. Verify against
# the training class order before editing -- any shift mislabels predictions.
CLASS_LABELS = [
    # Apple (0-3)
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___Healthy",
    # Blueberry (4)
    "Blueberry___Healthy",
    # Cherry (5-6)
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___Healthy",
    # Corn (7-10)
    "Corn_(maize)___Cercospora_leaf_spot_Gray_leaf_spot",
    "Corn_(maize)__Common_rust",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___Healthy",
    # Grape (11-14)
    "Grape___Black_rot",
    "Grape__Esca(Black_Measles)",
    "Grape__Leaf_blight(Isariopsis_Leaf_Spot)",
    "Grape___Healthy",
    # Orange (15)
    "Orange__Haunglongbing(Citrus_greening)",
    # Peach (16-17)
    "Peach___Bacterial_spot",
    "Peach___Healthy",
    # Pepper (18-19)
    "Pepper,bell__Bacterial_spot",
    "Pepper,bell__Healthy",
    # Potato (20-22)
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___Healthy",
    # Raspberry, Soybean, Squash (23-25)
    "Raspberry___Healthy",
    "Soybean___Healthy",
    "Squash___Powdery_mildew",
    # Strawberry (26-27)
    "Strawberry___Leaf_scorch",
    "Strawberry___Healthy",
    # Cotton and the suspicious duplicated run (28-35) -- see NOTE(review) above
    'cotton : bacterial_blight',
    ' cotton : curl_virus',
    ' fussarium_wilt',
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    ' fussarium_wilt',
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    # Tomato (36-40)
    "Tomato___Spider_mites_Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___Healthy",
]
# Registered as a startup hook: without this registration nothing ever assigns
# the global `model`, and every prediction endpoint answers "Model not loaded".
@app.on_event("startup")
async def load_model_keras():
    """Download the Keras model from the Hugging Face Hub and load it on startup.

    Stores the loaded model in the module-level ``model`` global. The file is
    cached under ``./model_cache`` by ``hf_hub_download``.

    Raises:
        Exception: any download or deserialisation error is printed and
            re-raised, so application startup fails instead of serving a
            model-less API.
    """
    global model
    repo_id = "AdeshJain/plant-detection"
    filename = "plant_disease_model.keras"
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="./model_cache",
        )
        model = keras.models.load_model(model_path)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    print(f"Model loaded successfully from {repo_id}!")

    # CLASS_LABELS is maintained by hand. If its length disagrees with the
    # network's output width, predictions are mislabeled (or CLASS_LABELS[idx]
    # raises IndexError), so surface the mismatch loudly at startup.
    try:
        n_outputs = model.output_shape[-1]
    except Exception:  # diagnostic only -- never block startup on it
        n_outputs = None
    if isinstance(n_outputs, int) and n_outputs != len(CLASS_LABELS):
        print(
            f"WARNING: model has {n_outputs} outputs but CLASS_LABELS has "
            f"{len(CLASS_LABELS)} entries; predicted labels may be wrong."
        )
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a single-item batch for the classifier.

    The image is coerced to RGB (when it is not already), resized to
    ``IMG_SIZE``, and its 0-255 pixel values are scaled into 0.0-1.0 as
    float32.

    Returns:
        Array of shape ``(1, IMG_SIZE[1], IMG_SIZE[0], 3)``.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    resized = rgb.resize(IMG_SIZE)
    normalized = np.asarray(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by model.predict.
    return normalized[np.newaxis, ...]
async def get_weather_data(lat: float, lon: float) -> Optional[Dict]:
    """Fetch current weather for a coordinate from OpenWeather (best effort).

    Args:
        lat: Latitude passed to the ``/data/2.5/weather`` endpoint.
        lon: Longitude passed to the same endpoint.

    Returns:
        The decoded JSON payload (metric units), or ``None`` when no API key
        is configured, the API answers with a non-200 status, or the request
        fails for any reason. Callers treat weather as optional context.
    """
    if not WEATHER_API_KEY:
        return None
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Let httpx build and encode the query string instead of splicing values
    # (including the API key) into the URL by hand.
    params = {
        "lat": lat,
        "lon": lon,
        "appid": WEATHER_API_KEY,
        "units": "metric",
    }
    try:
        async with httpx.AsyncClient(timeout=5.0) as client_http:
            response = await client_http.get(url, params=params)
        if response.status_code != 200:
            print(f"Weather API returned HTTP {response.status_code}")
            return None
        return response.json()
    except Exception as e:
        # Deliberately broad: weather is optional, so any network or decoding
        # failure degrades to "no weather" instead of failing the request.
        print(f"Weather API error: {e}")
        return None
def _build_weather_context(weather_data: Optional[Dict], location: str) -> str:
    """Return the weather paragraph embedded in the remedy prompt.

    Falls back to a "weather data unavailable" line when ``weather_data`` is
    falsy or lacks the expected OpenWeather fields.
    """
    if weather_data:
        try:
            temp = weather_data['main']['temp']
            humidity = weather_data['main']['humidity']
            weather_desc = weather_data['weather'][0]['description']
        except (KeyError, IndexError, TypeError):
            # Malformed payload: degrade to the no-weather prompt rather than
            # failing the whole request.
            print("Weather payload missing expected fields; omitting it from the prompt")
        else:
            return f"""
Current Weather Conditions at {location}:
- Temperature: {temp}°C
- Humidity: {humidity}%
- Conditions: {weather_desc}
"""
    return f"Location: {location}\n(Weather data unavailable)"


def get_llm_remedies(
    disease: str,
    weather_data: Optional[Dict],
    location: str,
    *,
    llm_model: str = "llama-3.3-70b-versatile",
    temperature: float = 0.7,
    max_tokens: int = 500,
) -> Dict:
    """Ask the Groq-hosted LLM for treatment advice for a detected disease.

    This performs a synchronous network call through the module-level Groq
    ``client``. From async code, run it in a worker thread so it does not
    block the event loop.

    Args:
        disease: Class label predicted by the CNN, inserted verbatim into the prompt.
        weather_data: OpenWeather payload for the location, or None.
        location: Human-readable place name used in the prompt.
            NOTE(review): this may originate from user input (the ``city``
            form field) and is inserted into the prompt unescaped.
        llm_model: Groq model id to query.
        temperature: Sampling temperature for the completion.
        max_tokens: Cap on generated tokens. NOTE(review): 500 is likely too
            small for the four detailed sections the prompt requests, so
            answers may be truncated -- consider raising it.

    Returns:
        ``{"remedies": <generated text>, "model_used": llm_model}``.

    Raises:
        HTTPException: status 500 wrapping any error raised by the Groq call.
    """
    weather_context = _build_weather_context(weather_data, location)
    prompt = f"""You are an expert agricultural consultant specializing in plant disease management. A farmer has detected the following plant disease through image analysis:
**Detected Disease: {disease}**
{weather_context}
Please provide comprehensive treatment recommendations in the following structured format:
## 1. CHEMICAL TREATMENT METHODS
Provide specific chemical treatments including:
- Recommended fungicides/pesticides (with active ingredients)
- Application dosage and concentration
- Application frequency and timing
- Safety precautions and protective equipment needed
- Pre-harvest intervals if applicable
## 2. SUSTAINABLE & NATURAL TREATMENT METHODS
Provide organic and eco-friendly solutions including:
- Natural/organic sprays and remedies
- Biological control methods
- Cultural practices and preventive measures
- Soil management techniques
- Plant-based solutions
## 3. WEATHER-SPECIFIC RECOMMENDATIONS
Based on the current weather conditions:
- How weather affects disease progression
- Best time to apply treatments
- Additional precautions needed
- Environmental considerations
## 4. PREVENTIVE MEASURES
Long-term strategies to prevent recurrence:
- Crop rotation suggestions
- Irrigation management
- Nutrient management
- Monitoring practices
Please be specific, practical, and actionable. Consider both immediate treatment and long-term disease management."""
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert agricultural consultant with deep knowledge of plant pathology and sustainable farming practices. Provide detailed, practical advice for farmers.",
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
            model=llm_model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}") from e
    return {
        "remedies": chat_completion.choices[0].message.content,
        "model_used": llm_model,
    }
# Registered at "/" so it is reachable. Without a route decorator FastAPI never
# exposed this handler.
@app.get("/")
async def root():
    """Health check: report that the API is up and whether the model is loaded."""
    return {
        "message": "CNN Image Prediction API is running",
        "status": "healthy",
        "model_loaded": model is not None,
    }
| # @app.get("/test") | |
| # async def test_prediction(): | |
| # """Test endpoint using a hardcoded image from local directory""" | |
| # if model is None: | |
| # raise HTTPException(status_code=500, detail="Model not loaded") | |
| # test_image_path = "test2.jpg" | |
| # if not os.path.exists(test_image_path): | |
| # raise HTTPException( | |
| # status_code=404, | |
| # detail=f"Test image not found at {test_image_path}. Please place a test image in the directory." | |
| # ) | |
| # try: | |
| # image = Image.open(test_image_path) | |
| # processed_img = preprocess_image(image) | |
| # prediction = model.predict(processed_img) | |
| # predicted_class_idx = int(np.argmax(prediction, axis=1)[0]) | |
| # confidence = float(np.max(prediction)) | |
| # predicted_class_name = CLASS_LABELS[predicted_class_idx] | |
| # top_5_indices = np.argsort(prediction[0])[-5:][::-1] | |
| # top_5_predictions = [ | |
| # { | |
| # "class": CLASS_LABELS[idx], | |
| # "confidence": float(prediction[0][idx]) | |
| # } | |
| # for idx in top_5_indices | |
| # ] | |
| # print("Top-5 indices & confidences:", top_5_indices) | |
| # return { | |
| # "top_prediction": { | |
| # "class": predicted_class_name, | |
| # "confidence": confidence | |
| # } | |
| # } | |
| # except Exception as e: | |
| # raise HTTPException(status_code=500, detail=f"Test prediction error: {str(e)}") | |
# NOTE(review): no route decorator was present, so this endpoint was never
# registered. The path below is assumed -- confirm it matches the client.
@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> Dict:
    """Classify an uploaded plant image with the loaded CNN.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        ``success``, the top label (``predicted_class``) with its index and
        confidence, the five highest-scoring labels with their scores
        (``top_5_predictions``), and the uploaded ``filename``.

    Raises:
        HTTPException: 500 if the model is not loaded; 400 if the upload is not
            declared as an image or Pillow cannot identify it; 500 for any
            other failure during prediction.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # content_type can be None when the part carries no Content-Type header;
    # treat that as "not an image" instead of crashing on None.startswith.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        processed_img = preprocess_image(image)
        # NOTE(review): model.predict is synchronous and runs on the event
        # loop, so concurrent requests queue behind each inference.
        prediction = model.predict(processed_img)
        predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
        confidence = float(np.max(prediction))
        predicted_class_name = CLASS_LABELS[predicted_class_idx]
        # Highest-scoring five classes, best first.
        top_5_indices = np.argsort(prediction[0])[-5:][::-1]
        top_5_predictions = [
            {
                "class": CLASS_LABELS[idx],
                "confidence": float(prediction[0][idx]),
            }
            for idx in top_5_indices
        ]
        return {
            "success": True,
            "predicted_class": predicted_class_name,
            "predicted_class_index": predicted_class_idx,
            "confidence": confidence,
            "top_5_predictions": top_5_predictions,
            "filename": file.filename,
        }
    except UnidentifiedImageError as e:
        # Bytes Pillow cannot decode are a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
def _summarize_weather(weather_data: Optional[Dict]) -> Optional[Dict]:
    """Pick the OpenWeather fields exposed in the API response.

    Returns None when there is no payload. Individual fields missing from the
    payload come back as None instead of failing an otherwise good prediction.
    """
    if not weather_data:
        return None
    main = weather_data.get('main') or {}
    conditions = weather_data.get('weather') or [{}]
    wind = weather_data.get('wind') or {}
    return {
        "temperature": main.get('temp'),
        "feels_like": main.get('feels_like'),
        "humidity": main.get('humidity'),
        "pressure": main.get('pressure'),
        "conditions": conditions[0].get('description'),
        "wind_speed": wind.get('speed'),
    }


# NOTE(review): no route decorator was present, so this endpoint was never
# registered. The path below is assumed -- confirm it matches the client.
@app.post("/predict-with-remedies")
async def predict_with_remedies(
    file: UploadFile = File(...),
    latitude: float = Form(...),
    longitude: float = Form(...),
    city: Optional[str] = Form(None),
) -> Dict:
    """
    Predict plant disease and provide AI-generated remedies with weather context.

    Args:
        file: Uploaded plant image (declared content type must be ``image/*``).
        latitude: Location latitude for weather data.
        longitude: Location longitude for weather data.
        city: Optional city name for display; when omitted the coordinates
            are used as the location name.

    Returns:
        Dictionary containing the prediction (with top-3 alternatives),
        location, weather summary (None if unavailable), and treatment text.
        Healthy plants get a fixed rule-based message instead of an LLM call.

    Raises:
        HTTPException: 500 if the model is not loaded; 400 if the upload is not
            declared as an image or cannot be decoded; errors raised by
            get_llm_remedies pass through unchanged; 500 for anything else.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # content_type can be None when the part carries no Content-Type header.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        # Step 1: Make disease prediction using CNN model.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        processed_img = preprocess_image(image)
        prediction = model.predict(processed_img)
        predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
        confidence = float(np.max(prediction))
        predicted_class_name = CLASS_LABELS[predicted_class_idx]
        # Top 3 predictions for additional context, best first.
        top_3_indices = np.argsort(prediction[0])[-3:][::-1]
        top_3_predictions = [
            {
                "class": CLASS_LABELS[idx],
                "confidence": float(prediction[0][idx]),
            }
            for idx in top_3_indices
        ]

        # Step 2: Fetch weather data for the location (None on any failure).
        location_name = city if city else f"Lat: {latitude}, Lon: {longitude}"
        weather_data = await get_weather_data(latitude, longitude)

        # Step 3: Remedies -- only diseased plants go to the LLM.
        is_healthy = "Healthy" in predicted_class_name
        if is_healthy:
            remedies_data = {
                "remedies": "🎉 Great news! Your plant appears to be healthy. No treatment needed.\n\n**Preventive Care Tips:**\n- Continue regular monitoring\n- Maintain proper watering schedule\n- Ensure adequate sunlight\n- Keep the area clean and weed-free\n- Monitor for any changes in plant appearance",
                "model_used": "rule-based",
            }
        else:
            # get_llm_remedies makes a blocking network call. Run it in a
            # worker thread so the event loop keeps serving other requests.
            remedies_data = await run_in_threadpool(
                get_llm_remedies, predicted_class_name, weather_data, location_name
            )

        return {
            "success": True,
            "prediction": {
                "disease": predicted_class_name,
                "confidence": confidence,
                "is_healthy": is_healthy,
                "top_3_predictions": top_3_predictions,
            },
            "location": {
                "name": location_name,
                "latitude": latitude,
                "longitude": longitude,
            },
            "weather": _summarize_weather(weather_data),
            "treatment": {
                "remedies": remedies_data["remedies"],
                "llm_model": remedies_data["model_used"],
            },
            "filename": file.filename,
        }
    except HTTPException:
        raise
    except UnidentifiedImageError as e:
        # Bytes Pillow cannot decode are a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
if __name__ == "__main__":
    # Serve on all interfaces. Port 7860 by default; a PORT environment
    # variable overrides it for hosts that assign the port at runtime.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))