Spaces:
Sleeping
Sleeping
| """ | |
| Memory-optimized FastAPI Backend for Crop Disease Detection | |
| Optimized to use <512MB RAM | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import io | |
| import json | |
| import sys | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| import tempfile | |
| import traceback | |
| import gc | |
| import psutil | |
# Add src to path for imports so the project modules resolve when the API is
# launched from this directory (both the package root and src/ are appended).
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
try:
    # Project-local components: model architecture, lightweight explainer,
    # weather-aware risk scoring, and inference-time image transforms.
    from src.model import CropDiseaseResNet50Lite
    from src.explain_lite import CropDiseaseExplainerLite
    from src.risk_level import RiskLevelCalculator
    from src.dataset import get_inference_transforms
except ImportError as e:
    # Best-effort startup: log and continue so the module still imports;
    # the missing components surface later when they are first used.
    print(f"Import error: {e}")
    print("Make sure all required modules are available")
# Initialize FastAPI app
app = FastAPI(
    title="Crop Disease Detection API (Optimized)",
    description="Memory-optimized AI-powered crop disease detection",
    version="2.1.0"
)
# Add CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# maximally permissive setup, and browsers reject credentialed responses with a
# wildcard origin anyway — confirm whether credentials are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model and components.
# All of these are populated by load_model_and_components() at startup.
model = None  # torch module (CropDiseaseResNet50Lite) once loaded
explainer = None  # CropDiseaseExplainerLite instance (optional feature)
risk_calculator = None  # RiskLevelCalculator instance
class_names = []  # ordered class labels matching the model's output head
device = None  # torch.device chosen at startup (cpu preferred for memory)
transforms = None  # preprocessing pipeline from get_inference_transforms()
def get_memory_usage():
    """Return this process's resident set size (RSS) in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)
def optimize_memory():
    """Run a garbage-collection pass; also flush the CUDA caching allocator when a GPU is present."""
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
def load_model_and_components():
    """Load the trained model and initialize supporting components.

    Populates the module-level globals (model, explainer, risk_calculator,
    class_names, device, transforms).

    Returns:
        bool: True when everything loaded; False on any failure (the error
        is printed with a traceback rather than raised).
    """
    global model, explainer, risk_calculator, class_names, device, transforms
    try:
        # Set device - prefer CPU for memory efficiency; only take CUDA when
        # the GPU has more than ~2GB of total memory.
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 2e9:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        print(f"Using device: {device}")
        # Optimized class names (reduced set for memory efficiency)
        class_names = [
            'Pepper_Bacterial_spot',
            'Pepper_healthy',
            'Potato_Early_blight',
            'Potato_healthy',
            'Potato_Late_blight',
            'Tomato_Target_Spot',
            'Tomato_mosaic_virus',
            'Tomato_Yellow_Leaf_Curl',
            'Tomato_Bacterial_spot',
            'Tomato_Early_blight',
            'Tomato_healthy',
            'Tomato_Late_blight',
            'Tomato_Leaf_Mold',
            'Tomato_Septoria_leaf_spot',
            'Tomato_Spider_mites'
        ]
        # Load model with memory optimization
        model_path = 'models/crop_disease_v3_model.pth'
        if os.path.exists(model_path):
            # Use lite version of model
            model = CropDiseaseResNet50Lite(num_classes=len(class_names), pretrained=False)
            # weights_only=True avoids unpickling arbitrary objects from the file
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
                if 'class_names' in checkpoint:
                    # NOTE(review): this replaces class_names AFTER the model head
                    # was already sized from the hard-coded list above — confirm the
                    # checkpoint's label count matches num_classes.
                    class_names = checkpoint['class_names']
            else:
                # Bare state dict saved directly (no wrapper dict)
                state_dict = checkpoint
        # Load state dict and immediately clear checkpoint from memory.
        # strict=False tolerates missing/unexpected keys from older checkpoints.
            model.load_state_dict(state_dict, strict=False)
            del checkpoint, state_dict
            optimize_memory()
            model.to(device)
            model.eval()
            # Enable memory efficient mode (optional hook on the lite model)
            if hasattr(model, 'set_memory_efficient'):
                model.set_memory_efficient(True)
            print(f"Lite model loaded from {model_path}")
        else:
            # No trained weights on disk: fall back to a pretrained backbone
            print("Warning: No trained model found. Creating lite model.")
            model = CropDiseaseResNet50Lite(num_classes=len(class_names), pretrained=True)
            model.to(device)
            model.eval()
        # Initialize lite explainer only if needed
        explainer = CropDiseaseExplainerLite(model, class_names, device)
        print("Lite explainer initialized")
        # Initialize risk calculator
        risk_calculator = RiskLevelCalculator()
        print("Risk calculator initialized")
        # Pre-load transforms so per-request calls can reuse them
        transforms = get_inference_transforms(input_size=224)
        # Force memory cleanup
        optimize_memory()
        memory_usage = get_memory_usage()
        print(f"Memory usage after loading: {memory_usage:.1f} MB")
        return True
    except Exception as e:
        print(f"Error loading model and components: {e}")
        traceback.print_exc()
        return False
async def startup_event():
    """Load the model and supporting components when the API starts.

    NOTE(review): no @app.on_event("startup") decorator is visible on this
    handler — presumably lost in extraction; confirm it is registered.
    """
    print("Starting optimized disease detection API...")
    if load_model_and_components():
        print("✅ All components loaded successfully")
    else:
        print("⚠️ Failed to load some components")
async def root():
    """Service description endpoint, including the current memory footprint.

    NOTE(review): no @app.get("/") decorator is visible — confirm the route
    is registered with the app.
    """
    endpoint_map = {
        "predict": "/predict - POST with image file",
        "health": "/health - GET for health check",
        "memory": "/memory - GET memory usage info"
    }
    return {
        "message": "Crop Disease Detection API (Optimized)",
        "version": "2.1.0",
        "status": "active",
        "memory_usage_mb": f"{get_memory_usage():.1f}",
        "optimization": "Memory optimized for <512MB usage",
        "endpoints": endpoint_map
    }
async def health_check():
    """Liveness probe reporting component load status and memory headroom.

    NOTE(review): no @app.get("/health") decorator is visible — confirm
    registration.
    """
    mb_in_use = get_memory_usage()
    report = {
        "status": "ok",
        "model_loaded": model is not None,
        "explainer_loaded": explainer is not None,
        "device": str(device) if device else "unknown",
        "memory_usage_mb": f"{mb_in_use:.1f}",
        "memory_optimized": mb_in_use < 512
    }
    return report
async def memory_info():
    """Detailed process, system, and GPU memory statistics.

    NOTE(review): no @app.get("/memory") decorator is visible — confirm
    registration.
    """
    mb = 1024 * 1024
    usage_mb = get_memory_usage()
    proc = psutil.Process(os.getpid())
    mem = proc.memory_info()
    if torch.cuda.is_available():
        gpu_allocated = f"{torch.cuda.memory_allocated() / 1024 / 1024:.1f}"
    else:
        gpu_allocated = "N/A"
    return {
        "memory_usage_mb": f"{usage_mb:.1f}",
        "memory_percent": f"{proc.memory_percent():.1f}%",
        "rss_mb": f"{mem.rss / mb:.1f}",
        "vms_mb": f"{mem.vms / mb:.1f}",
        "available_memory_mb": f"{psutil.virtual_memory().available / mb:.1f}",
        "gpu_memory_allocated": gpu_allocated,
        "optimization_status": "Optimized" if usage_mb < 512 else "Needs optimization"
    }
async def predict_disease(
    file: UploadFile = File(...),
    include_explanation: bool = Form(False),
    weather_humidity: Optional[float] = Form(None),
    weather_temperature: Optional[float] = Form(None),
    weather_rainfall: Optional[float] = Form(None)
):
    """
    Predict plant disease from uploaded image (memory optimized).

    Args:
        file: uploaded image (max 5MB; converted to RGB, downscaled to 224px).
        include_explanation: when True, attempt an explanation if memory
            headroom allows (<400MB currently in use).
        weather_humidity/weather_temperature/weather_rainfall: optional
            readings forwarded to the risk calculator.

    Returns:
        JSONResponse with the prediction, top-3 class probabilities, disease
        info, risk assessment, crop name, and memory bookkeeping.

    Raises:
        HTTPException: 503 if the model is not loaded, 413 for oversized
        uploads, 500 on unexpected failure.

    NOTE(review): no @app.post("/predict") decorator is visible on this
    handler — confirm the route is registered.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Memory optimization: track usage across the request
        initial_memory = get_memory_usage()
        # Read the upload ONCE — the stream cannot be re-read later.
        contents = await file.read()
        if len(contents) > 5 * 1024 * 1024:  # 5MB limit
            raise HTTPException(status_code=413, detail="Image too large. Maximum size: 5MB")
        # Process image with memory optimization
        image = Image.open(io.BytesIO(contents))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Downscale early to bound memory before tensor conversion
        max_size = 224
        if image.size[0] > max_size or image.size[1] > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        # Use the preloaded transforms when available
        if transforms is None:
            transforms_fn = get_inference_transforms(input_size=224)
        else:
            transforms_fn = transforms
        input_tensor = transforms_fn(image).unsqueeze(0).to(device)
        # BUG FIX: keep `contents` alive. The original deleted it here and then
        # called `await file.read()` again for the explanation, which returns
        # b"" once the upload stream has been consumed.
        del image
        optimize_memory()
        # Inference without autograd bookkeeping
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)
        predicted_class = class_names[predicted_idx.item()]
        confidence_score = confidence.item()
        # Get class probabilities (top 3 only to save memory)
        class_probs = {}
        top_probs, top_indices = torch.topk(probabilities[0], min(3, len(class_names)))
        for prob, idx in zip(top_probs, top_indices):
            class_probs[class_names[idx.item()]] = prob.item()
        # Clear tensors
        del input_tensor, outputs, probabilities
        optimize_memory()
        # Load disease information efficiently
        disease_info = get_disease_info_lite(predicted_class)
        # Only forward weather fields the client actually supplied
        weather_data = {}
        if weather_humidity is not None:
            weather_data['humidity'] = weather_humidity
        if weather_temperature is not None:
            weather_data['temperature'] = weather_temperature
        if weather_rainfall is not None:
            weather_data['rainfall'] = weather_rainfall
        risk_assessment = risk_calculator.calculate_risk(
            predicted_class, confidence_score, weather_data
        ) if risk_calculator else {"overall_risk": "unknown", "risk_factors": [], "recommendations": []}
        # Generate explanation only if requested and memory allows
        explanation_data = {}
        current_memory = get_memory_usage()
        if include_explanation and current_memory < 400 and explainer:
            try:
                # BUG FIX: reuse the bytes read above instead of re-reading the stream
                explanation_data = explainer.generate_explanation_lite(
                    contents, predicted_class
                )
            except Exception as e:
                print(f"Explanation generation failed: {e}")
                explanation_data = {"error": "Explanation unavailable due to memory constraints"}
        elif include_explanation:
            explanation_data = {"error": "Explanation disabled due to memory constraints"}
        # Final memory cleanup now that the image bytes are no longer needed
        del contents
        optimize_memory()
        final_memory = get_memory_usage()
        # Prepare response
        result = {
            "predicted_class": predicted_class,
            "confidence": confidence_score,
            "class_probabilities": class_probs,
            "disease_info": disease_info,
            "risk_assessment": risk_assessment,
            "crop": extract_crop_name(predicted_class),
            "memory_usage": {
                "initial_mb": f"{initial_memory:.1f}",
                "final_mb": f"{final_memory:.1f}",
                "memory_optimized": final_memory < 512
            }
        }
        if explanation_data:
            result["explanation"] = explanation_data
        return JSONResponse(content=result)
    except HTTPException:
        # BUG FIX: let deliberate HTTP errors (e.g. the 413 above) propagate
        # instead of being converted into generic 500s by the handler below.
        raise
    except Exception as e:
        # Cleanup on error
        optimize_memory()
        print(f"Prediction error: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
def get_disease_info_lite(disease_class: str) -> Dict[str, Any]:
    """Get trimmed disease information for *disease_class* from the JSON knowledge base.

    Returns a dict with 'symptoms', 'solutions', 'prevention' (each capped at
    three entries) and a 'description' truncated to 200 characters. Falls back
    to a generic placeholder dict when the knowledge base is missing,
    unreadable, or malformed — it never raises and never returns None.
    """
    # BUG FIX: the original fell through and implicitly returned None when the
    # knowledge-base file did not exist; callers expect a dict.
    fallback = {
        "symptoms": ["Symptoms information unavailable"],
        "solutions": ["Please consult agricultural expert"],
        "prevention": ["Follow general plant care guidelines"],
        "description": "Disease information unavailable"
    }
    try:
        # Load only essential disease info to save memory
        knowledge_base_path = Path(__file__).parent.parent / "knowledge_base" / "disease_info.json"
        if not knowledge_base_path.exists():
            return fallback
        with open(knowledge_base_path, 'r', encoding='utf-8') as f:
            all_disease_info = json.load(f)
        # Get specific disease info; unknown classes yield the default fields
        disease_info = all_disease_info.get(disease_class, {})
        # Return only essential fields to save memory
        return {
            "symptoms": disease_info.get("symptoms", [])[:3],  # Limit to 3 symptoms
            "solutions": disease_info.get("solutions", [])[:3],  # Limit to 3 solutions
            "prevention": disease_info.get("prevention", [])[:3],  # Limit to 3 prevention methods
            "description": disease_info.get("description", "No description available")[:200]  # Truncate description
        }
    except Exception as e:
        print(f"Error loading disease info: {e}")
        return fallback
def extract_crop_name(disease_class: str) -> str:
    """Map a disease class label to its crop name.

    Returns 'Pepper', 'Potato', or 'Tomato' based on the label's prefix
    (case variants accepted), or 'Unknown' for anything else.
    """
    for crop in ("Pepper", "Potato", "Tomato"):
        if disease_class.startswith((crop, crop.lower())):
            return crop
    return "Unknown"
| if __name__ == "__main__": | |
| import uvicorn | |
| print("Starting memory-optimized disease detection API...") | |
| print("Target: <512MB RAM usage") | |
| uvicorn.run( | |
| app, | |
| host="0.0.0.0", | |
| port=8001, | |
| workers=1, # Single worker to save memory | |
| limit_concurrency=2 # Limit concurrent requests | |
| ) |