Spaces:
Running
Running
| """ | |
| FastAPI App for Crop Disease Detection | |
| RESTful API replacement for Streamlit - Deployment-ready for Hugging Face Spaces | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Query | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict, Any | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import json | |
| import sys | |
| import os | |
| import uuid | |
| import tempfile | |
| import asyncio | |
| from pathlib import Path | |
| # Set matplotlib backend before importing pyplot (fixes headless environment) | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import base64 | |
| from datetime import datetime | |
| import time | |
| # Add src to path for imports | |
| sys.path.append('src') | |
| try: | |
| from src.model import CropDiseaseResNet50 | |
| from src.explain import CropDiseaseExplainer | |
| from src.risk_level import RiskLevelCalculator | |
| from torchvision import transforms | |
| except ImportError as e: | |
| print(f"Import error: {e}") | |
| raise e | |
# FastAPI app configuration
app = FastAPI(
    title="🌱 Crop Disease AI Detection API",
    description="RESTful API for AI-powered crop disease detection with Grad-CAM visualization",
    version="3.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Add CORS middleware so browser front-ends on any origin can call the API.
# NOTE(review): the CORS spec forbids credentialed requests with a literal "*"
# origin, so browsers will not send cookies/auth here despite
# allow_credentials=True — confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model and processing status
model = None            # loaded CropDiseaseResNet50; set by load_model_on_startup()
device = None           # torch.device ('cuda' when available, else 'cpu')
explainer = None        # CropDiseaseExplainer for Grad-CAM (may stay None on init failure)
risk_calculator = None  # RiskLevelCalculator (may stay None on init failure)
processing_status = {}  # in-memory task_id -> progress snapshot (per-process only; lost on restart)
class_names = []        # set to DEFAULT_CLASSES once the model loads
# Model classes (from V3 model).
# Order matters: index i must match the model's output logit i.
DEFAULT_CLASSES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___healthy',
    'Potato___Late_blight',
    'Tomato__Target_Spot',
    'Tomato__Tomato_mosaic_virus',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_healthy',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite'
]
# Pydantic models for API responses
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str                     # "healthy" when a model is loaded, else "unhealthy"
    ai_model_loaded: bool           # whether a checkpoint was successfully loaded
    ai_model_version: str           # checkpoint filename, or "not_loaded"
    available_endpoints: List[str]  # paths exposed by this API
    timestamp: str                  # ISO-8601 time the health check ran
    device: str                     # str(torch.device), e.g. "cpu" or "cuda"
class PredictionResponse(BaseModel):
    """Response body for POST /predict."""
    success: bool
    predicted_class: str                 # raw dataset label, e.g. "Potato___Early_blight"
    crop: str                            # crop portion parsed from the label
    disease: str                         # disease portion parsed from the label
    confidence: float                    # softmax probability of the top class (0..1)
    all_probabilities: Dict[str, float]  # label -> probability for every class
    risk_level: str                      # from RiskLevelCalculator, or "Unknown"
    processing_time: float               # seconds spent inside the request handler
    task_id: str                         # UUID for polling /status/{task_id} and /gradcam/{task_id}
class GradCAMResponse(BaseModel):
    """Response body for GET /gradcam/{task_id}."""
    success: bool
    heatmap_base64: str     # base64-encoded Grad-CAM overlay image
    explanation: str        # human-readable description of the heatmap
    task_id: str
    processing_time: float  # always 0.0: generated by a background task, duration not tracked
class StatusResponse(BaseModel):
    """Response body for GET /status/{task_id}."""
    task_id: str
    status: str     # "processing" | "completed" | "error"
    progress: int   # 0-100
    message: str    # human-readable progress message
    timestamp: str  # ISO-8601 time of the last status update
class WeatherData(BaseModel):
    """Optional weather context used for disease-risk scoring."""
    humidity: Optional[float] = 50.0     # relative humidity, percent
    temperature: Optional[float] = 25.0  # presumably °C — TODO confirm against RiskLevelCalculator
    rainfall: Optional[float] = 0.0      # units unspecified in SOURCE — confirm with callers
class PredictionRequest(BaseModel):
    """Prediction options as a request body.

    NOTE(review): /predict currently reads these as individual query
    parameters rather than via this model — confirm whether it is still used.
    """
    weather_data: Optional[WeatherData] = None
    include_gradcam: Optional[bool] = True
    include_disease_info: Optional[bool] = True
async def load_model_on_startup():
    """Load the trained model on startup.

    Populates the module globals (model, device, explainer, risk_calculator,
    class_names). Tries the V3 checkpoint first, then falls back to V2.
    Returns True on success, False when no usable checkpoint loads; in that
    case `model` stays None and /predict answers 503.
    """
    global model, device, explainer, risk_calculator, class_names
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🔧 Using device: {device}")
        # Try V3 model first, fallback to V2
        model_paths = [
            'models/crop_disease_v3_model.pth',
            'models/crop_disease_v2_model.pth'
        ]
        model = None
        model_name = None
        for model_path in model_paths:
            if os.path.exists(model_path):
                try:
                    # pretrained=False: all weights come from the checkpoint below.
                    model = CropDiseaseResNet50(num_classes=len(DEFAULT_CLASSES), pretrained=False)
                    # NOTE(review): torch.load without weights_only=True unpickles
                    # arbitrary objects — acceptable only for trusted local files.
                    checkpoint = torch.load(model_path, map_location=device)
                    # Handle different checkpoint formats
                    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                        state_dict = checkpoint['model_state_dict']
                    else:
                        state_dict = checkpoint
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                    model.eval()  # inference mode: freezes dropout / batch-norm behavior
                    model_name = os.path.basename(model_path)
                    break
                except Exception as e:
                    # Bad/incompatible checkpoint: try the next candidate.
                    print(f"Failed to load {model_path}: {e}")
                    continue
        if model is None:
            print("❌ No valid model found!")
            # Caught by the outer handler below, which converts it to `return False`.
            raise RuntimeError("No valid model found!")
        # Initialize explainer and risk calculator
        try:
            explainer = CropDiseaseExplainer(model, DEFAULT_CLASSES, device)
            risk_calculator = RiskLevelCalculator()
        except Exception as e:
            # Grad-CAM and risk scoring are optional extras; prediction still works.
            print(f"Failed to initialize explainer: {e}")
            explainer = None
            risk_calculator = None
        class_names = DEFAULT_CLASSES
        print(f"✅ Model loaded: {model_name}")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched model input tensor.

    Resizes to 224x224, converts to a float tensor, normalizes with the
    standard ImageNet mean/std, and prepends a batch dimension of 1.
    """
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
def predict_disease(model, device, image_tensor):
    """Run one forward pass and decode the result.

    Returns a tuple of (top class label, its softmax probability, and a
    mapping of every label in DEFAULT_CLASSES to its probability).
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probs = F.softmax(logits, dim=1)
    top_prob, top_idx = torch.max(probs, 1)
    row = probs[0]
    class_probabilities = {
        label: row[i].item() for i, label in enumerate(DEFAULT_CLASSES)
    }
    return DEFAULT_CLASSES[top_idx.item()], top_prob.item(), class_probabilities
def parse_class_name(class_name):
    """Split a dataset label into (crop, disease).

    The PlantVillage-style labels use inconsistent separators; try the
    longest first ('___', then '__', then '_'). A label with no underscore
    at all yields ("Unknown", label).
    """
    for sep in ('___', '__', '_'):
        if sep not in class_name:
            continue
        # '___' labels are split on every occurrence (segments 0 and 1 are
        # used); the shorter separators split only once.
        pieces = class_name.split(sep) if sep == '___' else class_name.split(sep, 1)
        return pieces[0], pieces[1]
    return "Unknown", class_name
def get_disease_info(crop, disease):
    """Look up a disease entry in the local JSON knowledge base.

    Matching is case-insensitive substring containment against each entry's
    'crop' and 'disease' fields. Returns the first matching entry dict, or
    None when the knowledge base is missing/malformed or nothing matches.
    """
    try:
        # Explicit encoding: KB text may contain non-ASCII characters.
        with open('knowledge_base/disease_info.json', 'r', encoding='utf-8') as f:
            kb_data = json.load(f)
        for entry in kb_data['diseases']:
            if crop.lower() in entry['crop'].lower() and disease.lower() in entry['disease'].lower():
                return entry
    except (OSError, json.JSONDecodeError, KeyError, TypeError, AttributeError):
        # Best-effort lookup: a missing or malformed knowledge base must not
        # break predictions. Narrowed from a bare `except Exception` so
        # genuine programming errors are no longer silently swallowed.
        pass
    return None
def update_processing_status(task_id: str, status: str, progress: int, message: str):
    """Record the latest progress snapshot for *task_id* in the in-memory store."""
    snapshot = {
        "status": status,
        "progress": progress,
        "message": message,
        "timestamp": datetime.now().isoformat(),
    }
    processing_status[task_id] = snapshot
# FastAPI Events
# NOTE(review): no @app.on_event("startup") decorator is visible on this
# handler — it may have been lost in formatting; confirm it is registered.
async def startup_event():
    """Initialize the model once at application startup."""
    print("🚀 Starting Crop Disease Detection API...")
    loaded = await load_model_on_startup()
    # Fix: the original discarded the return value and always printed the
    # ready banner, even when no model could be loaded.
    if loaded:
        print("✅ API ready to serve requests!")
    else:
        print("⚠️ API started WITHOUT a loaded model — /predict will return 503.")
# API Endpoints
async def root():
    """Landing JSON: identify the API and list its main endpoints."""
    endpoint_map = {
        "health": "/health",
        "predict": "/predict",
        "gradcam": "/gradcam/{task_id}",
        "status": "/status/{task_id}",
    }
    return {
        "message": "🌱 Crop Disease Detection API",
        "version": "3.0.0",
        "status": "running",
        "docs": "/docs",
        "endpoints": endpoint_map,
    }
async def health_check():
    """Health check endpoint: report model and device status for monitoring."""
    global model, device
    ai_model_loaded = model is not None
    device_str = str(device) if device else "unknown"
    # NOTE(review): the version string is hard-coded to the V3 filename even
    # when the V2 fallback checkpoint was loaded — confirm, or track the
    # actually-loaded filename in a global.
    ai_model_version = "crop_disease_v3_model.pth" if ai_model_loaded else "not_loaded"
    return HealthResponse(
        status="healthy" if ai_model_loaded else "unhealthy",
        ai_model_loaded=ai_model_loaded,
        ai_model_version=ai_model_version,
        available_endpoints=["/health", "/predict", "/gradcam/{task_id}", "/status/{task_id}"],
        timestamp=datetime.now().isoformat(),
        device=device_str
    )
async def predict_crop_disease(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    weather_data: Optional[str] = Query(None, description="JSON string of weather data"),
    include_gradcam: bool = Query(True, description="Generate Grad-CAM heatmap"),
    include_disease_info: bool = Query(True, description="Include disease information")
):
    """
    Predict crop disease from uploaded image.

    Pipeline: validate upload -> preprocess -> model inference -> parse label
    -> best-effort risk scoring -> optionally schedule Grad-CAM generation in
    the background. Progress is published to processing_status under task_id.

    Raises 503 when no model is loaded, 400 for unsupported content types,
    500 on any processing failure.
    NOTE(review): include_disease_info is accepted but never used in this
    handler — confirm whether disease info was meant to be attached here.
    """
    global model, device, risk_calculator
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Validate file type
    if file.content_type not in ["image/jpeg", "image/jpg", "image/png", "image/bmp"]:
        raise HTTPException(status_code=400, detail="Invalid file type. Only JPEG, PNG, and BMP are supported.")
    # Clients poll /status/{task_id} and /gradcam/{task_id} with this id.
    task_id = str(uuid.uuid4())
    start_time = time.time()
    try:
        # Update status: Image uploaded
        update_processing_status(task_id, "processing", 10, "Image uploaded successfully")
        # Read and process image; convert('RGB') normalizes grayscale/RGBA input.
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Update status: Preprocessing
        update_processing_status(task_id, "processing", 30, "Preprocessing image")
        # Preprocess image
        image_tensor = preprocess_image(image)
        # Update status: Model running
        update_processing_status(task_id, "processing", 50, "Running inference")
        # Make prediction
        predicted_class, confidence_score, class_probabilities = predict_disease(
            model, device, image_tensor
        )
        # Parse class name
        crop, disease = parse_class_name(predicted_class)
        # Update status: Risk assessment
        update_processing_status(task_id, "processing", 70, "Calculating risk assessment")
        # Calculate risk level
        risk_level = "Unknown"
        if risk_calculator:
            try:
                weather = {}
                if weather_data:
                    weather = json.loads(weather_data)
                # Fill defaults for any missing weather fields.
                weather_data_obj = {
                    'humidity': weather.get('humidity', 50.0),
                    'temperature': weather.get('temperature', 25.0),
                    'rainfall': weather.get('rainfall', 0.0)
                }
                risk_assessment = risk_calculator.calculate_enhanced_risk(
                    predicted_class, confidence_score, weather_data_obj, None
                )
                risk_level = risk_assessment.get('risk_level', 'Unknown')
            except Exception as e:
                # Risk scoring is best-effort: fall back to "Unknown" on any
                # error, including malformed weather_data JSON.
                print(f"Risk assessment error: {e}")
        # Update status: Completed
        update_processing_status(task_id, "completed", 100, "Analysis completed successfully")
        processing_time = time.time() - start_time
        # Schedule Grad-CAM generation if requested (runs after the response is sent).
        if include_gradcam and explainer:
            background_tasks.add_task(generate_gradcam_background, task_id, image_bytes)
        return PredictionResponse(
            success=True,
            predicted_class=predicted_class,
            crop=crop,
            disease=disease,
            confidence=confidence_score,
            all_probabilities=class_probabilities,
            risk_level=risk_level,
            processing_time=processing_time,
            task_id=task_id
        )
    except Exception as e:
        update_processing_status(task_id, "error", 0, f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
async def generate_gradcam_background(task_id: str, image_bytes: bytes):
    """Generate Grad-CAM heatmap in background.

    Stores the result (or the error) under the "<task_id>_gradcam" key of
    processing_status, where GET /gradcam/{task_id} picks it up. Never
    raises: any failure is recorded for the client to poll.
    """
    global explainer
    try:
        update_processing_status(task_id, "processing", 80, "Generating Grad-CAM heatmap")
        # Save temporary image — the explainer takes a file path, not bytes.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            temp_path = tmp_file.name
        try:
            # Generate explanation
            explanation = explainer.explain_prediction(temp_path, return_base64=True)
            if 'overlay_base64' in explanation:
                # Store the result
                processing_status[f"{task_id}_gradcam"] = {
                    "success": True,
                    "heatmap_base64": explanation['overlay_base64'],
                    "explanation": "Grad-CAM heatmap showing areas the AI model focused on for prediction",
                    "timestamp": datetime.now().isoformat()
                }
            else:
                error_msg = explanation.get('error', 'Unknown error generating Grad-CAM')
                processing_status[f"{task_id}_gradcam"] = {
                    "success": False,
                    "error": error_msg,
                    "timestamp": datetime.now().isoformat()
                }
        finally:
            # Clean up temp file (delete=False above, so we must unlink ourselves)
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        processing_status[f"{task_id}_gradcam"] = {
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }
async def get_gradcam(task_id: str):
    """Return the Grad-CAM heatmap produced in the background for *task_id*."""
    result = processing_status.get(f"{task_id}_gradcam")
    if result is None:
        raise HTTPException(status_code=404, detail="Grad-CAM not found or still processing")
    if not result.get("success", False):
        detail = f"Grad-CAM generation failed: {result.get('error', 'Unknown error')}"
        raise HTTPException(status_code=500, detail=detail)
    return GradCAMResponse(
        success=True,
        heatmap_base64=result["heatmap_base64"],
        explanation=result["explanation"],
        task_id=task_id,
        processing_time=0.0,  # produced by a background task; duration not tracked
    )
async def get_status(task_id: str):
    """Return the current processing snapshot for *task_id* (404 if unknown)."""
    try:
        record = processing_status[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found")
    return StatusResponse(
        task_id=task_id,
        status=record["status"],
        progress=record["progress"],
        message=record["message"],
        timestamp=record["timestamp"],
    )
async def get_disease_information(crop: str, disease: str):
    """Fetch knowledge-base details for a crop/disease pair."""
    info = get_disease_info(crop, disease)
    if info:
        return {"success": True, "data": info}
    return {"success": False, "message": "Disease information not found"}
if __name__ == "__main__":
    import uvicorn
    # Bind to all interfaces: the file targets Hugging Face Spaces (see module
    # docstring), where traffic is routed into the container on port 7860 and
    # a localhost-only bind is unreachable from outside.
    uvicorn.run(app, host="0.0.0.0", port=7860)