""" FastAPI App for Crop Disease Detection RESTful API replacement for Streamlit - Deployment-ready for Hugging Face Spaces """ from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Query from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional, List, Dict, Any import torch import torch.nn.functional as F import numpy as np from PIL import Image import io import json import sys import os import uuid import tempfile import asyncio from pathlib import Path # Set matplotlib backend before importing pyplot (fixes headless environment) import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import base64 from datetime import datetime import time # Add src to path for imports sys.path.append('src') try: from src.model import CropDiseaseResNet50 from src.explain import CropDiseaseExplainer from src.risk_level import RiskLevelCalculator from torchvision import transforms except ImportError as e: print(f"Import error: {e}") raise e # FastAPI app configuration app = FastAPI( title="🌱 Crop Disease AI Detection API", description="RESTful API for AI-powered crop disease detection with Grad-CAM visualization", version="3.0.0", docs_url="/docs", redoc_url="/redoc" ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global variables for model and processing status model = None device = None explainer = None risk_calculator = None processing_status = {} class_names = [] # Model classes (from V3 model) DEFAULT_CLASSES = [ 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Potato___Late_blight', 'Tomato__Target_Spot', 'Tomato__Tomato_mosaic_virus', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_healthy', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite' ] # Pydantic models for API responses class HealthResponse(BaseModel): status: str ai_model_loaded: bool ai_model_version: str available_endpoints: List[str] timestamp: str device: str class PredictionResponse(BaseModel): success: bool predicted_class: str crop: str disease: str confidence: float all_probabilities: Dict[str, float] risk_level: str processing_time: float task_id: str class GradCAMResponse(BaseModel): success: bool heatmap_base64: str explanation: str task_id: str processing_time: float class StatusResponse(BaseModel): task_id: str status: str progress: int message: str timestamp: str class WeatherData(BaseModel): humidity: Optional[float] = 50.0 temperature: Optional[float] = 25.0 rainfall: Optional[float] = 0.0 class PredictionRequest(BaseModel): weather_data: Optional[WeatherData] = None include_gradcam: Optional[bool] = True include_disease_info: Optional[bool] = True async def load_model_on_startup(): """Load the trained model on startup""" global model, device, explainer, risk_calculator, class_names try: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"🔧 Using device: {device}") # Try V3 model first, fallback to V2 model_paths = [ 'models/crop_disease_v3_model.pth', 'models/crop_disease_v2_model.pth' ] model = None model_name = None for model_path in model_paths: if os.path.exists(model_path): try: model = CropDiseaseResNet50(num_classes=len(DEFAULT_CLASSES), pretrained=False) checkpoint = torch.load(model_path, map_location=device) # Handle different checkpoint formats if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] else: state_dict = checkpoint model.load_state_dict(state_dict, strict=True) model.to(device) model.eval() model_name = os.path.basename(model_path) break except Exception as e: print(f"Failed to load {model_path}: {e}") continue if model is None: print("❌ No valid model found!") raise RuntimeError("No valid model found!") # Initialize explainer and risk calculator try: explainer = CropDiseaseExplainer(model, DEFAULT_CLASSES, device) risk_calculator = RiskLevelCalculator() except Exception as e: print(f"Failed to initialize explainer: {e}") explainer = None risk_calculator = None class_names = DEFAULT_CLASSES print(f"✅ Model loaded: {model_name}") return True except Exception as e: print(f"Error loading model: {e}") return False def preprocess_image(image): """Preprocess image for model input""" transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return transform(image).unsqueeze(0) def predict_disease(model, device, image_tensor): """Make disease prediction""" with torch.no_grad(): outputs = model(image_tensor.to(device)) probabilities = F.softmax(outputs, dim=1) confidence, predicted_idx = torch.max(probabilities, 1) predicted_class = DEFAULT_CLASSES[predicted_idx.item()] confidence_score = confidence.item() # Get all class probabilities class_probabilities = { DEFAULT_CLASSES[i]: probabilities[0, i].item() for i in range(len(DEFAULT_CLASSES)) } return predicted_class, confidence_score, class_probabilities def parse_class_name(class_name): """Parse crop and disease from class name""" if '___' in class_name: parts = class_name.split('___') crop = parts[0] disease = parts[1] elif '__' in class_name: parts = class_name.split('__', 1) crop = parts[0] disease = parts[1] elif '_' in class_name: parts = class_name.split('_', 1) crop = parts[0] disease = parts[1] else: crop = "Unknown" disease = class_name return crop, disease def get_disease_info(crop, disease): """Get disease information from knowledge base""" try: with open('knowledge_base/disease_info.json', 'r') as f: kb_data = json.load(f) for d in kb_data['diseases']: if crop.lower() in d['crop'].lower() and disease.lower() in d['disease'].lower(): return d except Exception: pass return None def update_processing_status(task_id: str, status: str, progress: int, message: str): """Update processing status for a task""" processing_status[task_id] = { "status": status, "progress": progress, "message": message, "timestamp": datetime.now().isoformat() } # FastAPI Events @app.on_event("startup") async def startup_event(): """Initialize model on startup""" print("🚀 Starting Crop Disease Detection API...") await load_model_on_startup() print("✅ API ready to serve requests!") # API Endpoints @app.get("/", response_model=Dict[str, Any]) async def root(): """Root endpoint with API information""" return { "message": "🌱 Crop Disease Detection API", "version": "3.0.0", "status": "running", "docs": "/docs", "endpoints": { "health": "/health", "predict": "/predict", "gradcam": "/gradcam/{task_id}", "status": "/status/{task_id}" } } @app.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" global model, device ai_model_loaded = model is not None device_str = str(device) if device else "unknown" ai_model_version = "crop_disease_v3_model.pth" if ai_model_loaded else "not_loaded" return HealthResponse( status="healthy" if ai_model_loaded else "unhealthy", ai_model_loaded=ai_model_loaded, ai_model_version=ai_model_version, available_endpoints=["/health", "/predict", "/gradcam/{task_id}", "/status/{task_id}"], timestamp=datetime.now().isoformat(), device=device_str ) @app.post("/predict", response_model=PredictionResponse) async def predict_crop_disease( background_tasks: BackgroundTasks, file: UploadFile = File(...), weather_data: Optional[str] = Query(None, description="JSON string of weather data"), include_gradcam: bool = Query(True, description="Generate Grad-CAM heatmap"), include_disease_info: bool = Query(True, description="Include disease information") ): """ Predict crop disease from uploaded image """ global model, device, risk_calculator if model is None: raise HTTPException(status_code=503, detail="Model not loaded") # Validate file type if file.content_type not in ["image/jpeg", "image/jpg", "image/png", "image/bmp"]: raise HTTPException(status_code=400, detail="Invalid file type. Only JPEG, PNG, and BMP are supported.") task_id = str(uuid.uuid4()) start_time = time.time() try: # Update status: Image uploaded update_processing_status(task_id, "processing", 10, "Image uploaded successfully") # Read and process image image_bytes = await file.read() image = Image.open(io.BytesIO(image_bytes)).convert('RGB') # Update status: Preprocessing update_processing_status(task_id, "processing", 30, "Preprocessing image") # Preprocess image image_tensor = preprocess_image(image) # Update status: Model running update_processing_status(task_id, "processing", 50, "Running inference") # Make prediction predicted_class, confidence_score, class_probabilities = predict_disease( model, device, image_tensor ) # Parse class name crop, disease = parse_class_name(predicted_class) # Update status: Risk assessment update_processing_status(task_id, "processing", 70, "Calculating risk assessment") # Calculate risk level risk_level = "Unknown" if risk_calculator: try: weather = {} if weather_data: weather = json.loads(weather_data) weather_data_obj = { 'humidity': weather.get('humidity', 50.0), 'temperature': weather.get('temperature', 25.0), 'rainfall': weather.get('rainfall', 0.0) } risk_assessment = risk_calculator.calculate_enhanced_risk( predicted_class, confidence_score, weather_data_obj, None ) risk_level = risk_assessment.get('risk_level', 'Unknown') except Exception as e: print(f"Risk assessment error: {e}") # Update status: Completed update_processing_status(task_id, "completed", 100, "Analysis completed successfully") processing_time = time.time() - start_time # Schedule Grad-CAM generation if requested if include_gradcam and explainer: background_tasks.add_task(generate_gradcam_background, task_id, image_bytes) return PredictionResponse( success=True, predicted_class=predicted_class, crop=crop, disease=disease, confidence=confidence_score, all_probabilities=class_probabilities, risk_level=risk_level, processing_time=processing_time, task_id=task_id ) except Exception as e: update_processing_status(task_id, "error", 0, f"Error: {str(e)}") raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") async def generate_gradcam_background(task_id: str, image_bytes: bytes): """Generate Grad-CAM heatmap in background""" global explainer try: update_processing_status(task_id, "processing", 80, "Generating Grad-CAM heatmap") # Save temporary image with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file: tmp_file.write(image_bytes) temp_path = tmp_file.name try: # Generate explanation explanation = explainer.explain_prediction(temp_path, return_base64=True) if 'overlay_base64' in explanation: # Store the result processing_status[f"{task_id}_gradcam"] = { "success": True, "heatmap_base64": explanation['overlay_base64'], "explanation": "Grad-CAM heatmap showing areas the AI model focused on for prediction", "timestamp": datetime.now().isoformat() } else: error_msg = explanation.get('error', 'Unknown error generating Grad-CAM') processing_status[f"{task_id}_gradcam"] = { "success": False, "error": error_msg, "timestamp": datetime.now().isoformat() } finally: # Clean up temp file if os.path.exists(temp_path): os.unlink(temp_path) except Exception as e: processing_status[f"{task_id}_gradcam"] = { "success": False, "error": str(e), "timestamp": datetime.now().isoformat() } @app.get("/gradcam/{task_id}", response_model=GradCAMResponse) async def get_gradcam(task_id: str): """Get Grad-CAM heatmap for a prediction task""" gradcam_key = f"{task_id}_gradcam" if gradcam_key not in processing_status: raise HTTPException(status_code=404, detail="Grad-CAM not found or still processing") result = processing_status[gradcam_key] if not result.get("success", False): raise HTTPException(status_code=500, detail=f"Grad-CAM generation failed: {result.get('error', 'Unknown error')}") return GradCAMResponse( success=True, heatmap_base64=result["heatmap_base64"], explanation=result["explanation"], task_id=task_id, processing_time=0.0 # Background task, time not tracked ) @app.get("/status/{task_id}", response_model=StatusResponse) async def get_status(task_id: str): """Get processing status for a task""" if task_id not in processing_status: raise HTTPException(status_code=404, detail="Task not found") status = processing_status[task_id] return StatusResponse( task_id=task_id, status=status["status"], progress=status["progress"], message=status["message"], timestamp=status["timestamp"] ) @app.get("/disease-info") async def get_disease_information(crop: str, disease: str): """Get disease information from knowledge base""" disease_info = get_disease_info(crop, disease) if disease_info: return {"success": True, "data": disease_info} else: return {"success": False, "message": "Disease information not found"} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="localhost", port=7860)