# crop_ai_diseases / api /main_optimized.py
# vivek12coder's picture
# Upload 20960 files
# c8df794 verified
"""
Memory-optimized FastAPI Backend for Crop Disease Detection
Optimized to use <512MB RAM
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import torch
import torch.nn.functional as F
from PIL import Image
import io
import json
import sys
import os
from pathlib import Path
from typing import Optional, Dict, Any
import tempfile
import traceback
import gc
import psutil
# Extend the import search path with this file's parent directory and that
# directory's ``src`` folder, so the ``src.*`` imports below can resolve.
_api_dir = os.path.dirname(__file__)
for _parts in (('..', 'src'), ('..',)):
    sys.path.append(os.path.join(_api_dir, *_parts))

try:
    from src.model import CropDiseaseResNet50Lite
    from src.explain_lite import CropDiseaseExplainerLite
    from src.risk_level import RiskLevelCalculator
    from src.dataset import get_inference_transforms
except ImportError as import_error:
    # Best effort: report the failure and keep the module importable. The
    # names above stay undefined in that case.
    print(
        f"Import error: {import_error}",
        "Make sure all required modules are available",
        sep="\n",
    )
# Application object and the metadata it is created with.
_app_metadata = {
    "title": "Crop Disease Detection API (Optimized)",
    "description": "Memory-optimized AI-powered crop disease detection",
    "version": "2.1.0",
}
app = FastAPI(**_app_metadata)

# CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
_cors_policy = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_policy)
# Module-level state shared by the endpoints. Everything starts empty and is
# filled in by load_model_and_components().
model = explainer = risk_calculator = None
device = transforms = None
class_names = []
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    # Two divisions by 1024: bytes -> KB -> MB.
    return rss_bytes / 1024 / 1024
def optimize_memory():
    """Run the garbage collector and, if CUDA is available, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def load_model_and_components():
    """Load the model and initialise the explainer, risk calculator and transforms.

    Populates the module-level ``model``, ``explainer``, ``risk_calculator``,
    ``class_names``, ``device`` and ``transforms`` globals.

    Returns:
        True when every step succeeded; False when any step raised (the
        error and its traceback are printed, nothing is re-raised).
    """
    global model, explainer, risk_calculator, class_names, device, transforms
    try:
        # Use CUDA only when it is available and device 0 reports more than
        # 2e9 bytes of total memory; otherwise stay on the CPU.
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 2e9:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        print(f"Using device: {device}")

        # Default class list; a checkpoint that carries its own list
        # overrides it below.
        class_names = [
            'Pepper_Bacterial_spot',
            'Pepper_healthy',
            'Potato_Early_blight',
            'Potato_healthy',
            'Potato_Late_blight',
            'Tomato_Target_Spot',
            'Tomato_mosaic_virus',
            'Tomato_Yellow_Leaf_Curl',
            'Tomato_Bacterial_spot',
            'Tomato_Early_blight',
            'Tomato_healthy',
            'Tomato_Late_blight',
            'Tomato_Leaf_Mold',
            'Tomato_Septoria_leaf_spot',
            'Tomato_Spider_mites'
        ]

        # NOTE: this path is relative, so it is resolved against the current
        # working directory of the process.
        model_path = 'models/crop_disease_v3_model.pth'
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
                saved_class_names = checkpoint.get('class_names')
                # An empty/None saved list is ignored so the defaults remain.
                if saved_class_names:
                    class_names = saved_class_names
            else:
                # The file holds the bare state dict.
                state_dict = checkpoint

            # Build the network only after the class list is final, so the
            # output size matches the checkpoint. strict=False tolerates
            # missing/unexpected keys but not tensors of the wrong shape.
            model = CropDiseaseResNet50Lite(num_classes=len(class_names), pretrained=False)
            model.load_state_dict(state_dict, strict=False)

            # Drop the checkpoint copies as soon as the weights are loaded.
            del checkpoint, state_dict
            optimize_memory()

            model.to(device)
            model.eval()

            # Optional hook: only called when the model class provides it.
            if hasattr(model, 'set_memory_efficient'):
                model.set_memory_efficient(True)
            print(f"Lite model loaded from {model_path}")
        else:
            # NOTE(review): no trained checkpoint is loaded in this branch, so
            # predictions are presumably not reliable — confirm intended use.
            print("Warning: No trained model found. Creating lite model.")
            model = CropDiseaseResNet50Lite(num_classes=len(class_names), pretrained=True)
            model.to(device)
            model.eval()

        explainer = CropDiseaseExplainerLite(model, class_names, device)
        print("Lite explainer initialized")

        risk_calculator = RiskLevelCalculator()
        print("Risk calculator initialized")

        # Build the inference transforms once so requests can reuse them.
        transforms = get_inference_transforms(input_size=224)

        optimize_memory()
        memory_usage = get_memory_usage()
        print(f"Memory usage after loading: {memory_usage:.1f} MB")
        return True
    except Exception as e:
        # Startup boundary: report the failure and signal it via the return
        # value instead of crashing the application.
        print(f"Error loading model and components: {e}")
        traceback.print_exc()
        return False
@app.on_event("startup")
async def startup_event():
    """Load the model and helper components when the application starts."""
    print("Starting optimized disease detection API...")
    if load_model_and_components():
        outcome = "✅ All components loaded successfully"
    else:
        outcome = "⚠️ Failed to load some components"
    print(outcome)
@app.get("/")
async def root():
    """Describe the service: name, version, memory use and endpoints."""
    endpoint_help = {
        "predict": "/predict - POST with image file",
        "health": "/health - GET for health check",
        "memory": "/memory - GET memory usage info",
    }
    rss_mb = get_memory_usage()
    return {
        "message": "Crop Disease Detection API (Optimized)",
        "version": "2.1.0",
        "status": "active",
        "memory_usage_mb": format(rss_mb, ".1f"),
        "optimization": "Memory optimized for <512MB usage",
        "endpoints": endpoint_help,
    }
@app.get("/health")
async def health_check():
    """Report whether the model and explainer are loaded, plus memory use."""
    rss_mb = get_memory_usage()
    device_label = str(device) if device else "unknown"
    return dict(
        status="ok",
        model_loaded=model is not None,
        explainer_loaded=explainer is not None,
        device=device_label,
        memory_usage_mb=f"{rss_mb:.1f}",
        # True while the process stays under the 512 MB target.
        memory_optimized=rss_mb < 512,
    )
@app.get("/memory")
async def memory_info():
    """Report process, system and (when CUDA is available) GPU memory figures."""
    def _mb(byte_count):
        # Bytes -> MB, rendered with one decimal place.
        return f"{byte_count / 1024 / 1024:.1f}"

    current_mb = get_memory_usage()
    proc = psutil.Process(os.getpid())
    proc_mem = proc.memory_info()

    report = {"memory_usage_mb": f"{current_mb:.1f}"}
    report["memory_percent"] = f"{proc.memory_percent():.1f}%"
    report["rss_mb"] = _mb(proc_mem.rss)
    report["vms_mb"] = _mb(proc_mem.vms)
    report["available_memory_mb"] = _mb(psutil.virtual_memory().available)
    if torch.cuda.is_available():
        report["gpu_memory_allocated"] = _mb(torch.cuda.memory_allocated())
    else:
        report["gpu_memory_allocated"] = "N/A"
    report["optimization_status"] = "Optimized" if current_mb < 512 else "Needs optimization"
    return report
@app.post("/predict")
async def predict_disease(
    file: UploadFile = File(...),
    include_explanation: bool = Form(False),
    weather_humidity: Optional[float] = Form(None),
    weather_temperature: Optional[float] = Form(None),
    weather_rainfall: Optional[float] = Form(None)
):
    """Predict the plant disease shown in an uploaded image.

    Args:
        file: Uploaded image; rejected when larger than 5 MB.
        include_explanation: When true, an explanation is attempted if the
            explainer is loaded and process memory is below 400 MB.
        weather_humidity: Optional value passed to the risk calculator.
        weather_temperature: Optional value passed to the risk calculator.
        weather_rainfall: Optional value passed to the risk calculator.

    Returns:
        JSONResponse with the predicted class, its confidence, the top
        class probabilities (at most three), disease info, risk assessment,
        crop name, memory figures and, when produced, an ``explanation``.

    Raises:
        HTTPException: 503 when no model is loaded, 413 when the upload
            exceeds 5 MB, 400 when the upload cannot be decoded as an
            image, 500 for any other failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    max_upload_bytes = 5 * 1024 * 1024  # 5MB limit

    try:
        initial_memory = get_memory_usage()

        # Read at most one byte past the limit: enough to detect an oversized
        # upload without pulling the whole body into memory first.
        contents = await file.read(max_upload_bytes + 1)
        if len(contents) > max_upload_bytes:
            raise HTTPException(status_code=413, detail="Image too large. Maximum size: 5MB")

        try:
            image = Image.open(io.BytesIO(contents))
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Shrink large images in place before the transforms run.
            max_size = 224
            if image.size[0] > max_size or image.size[1] > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        except OSError as image_error:
            # Pillow reports unidentifiable or truncated files as OSError
            # subclasses; that is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail="Invalid or unreadable image file"
            ) from image_error

        # Reuse the transforms built at startup; build them only if missing.
        if transforms is None:
            transforms_fn = get_inference_transforms(input_size=224)
        else:
            transforms_fn = transforms
        input_tensor = transforms_fn(image).unsqueeze(0).to(device)

        # The decoded image and raw bytes are no longer needed for inference.
        del image, contents
        optimize_memory()

        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)
            predicted_class = class_names[predicted_idx.item()]
            confidence_score = confidence.item()

            # Keep only the top 3 classes (fewer if fewer classes exist).
            top_probs, top_indices = torch.topk(probabilities[0], min(3, len(class_names)))
            class_probs = {
                class_names[idx.item()]: prob.item()
                for prob, idx in zip(top_probs, top_indices)
            }

        del input_tensor, outputs, probabilities
        optimize_memory()

        disease_info = get_disease_info_lite(predicted_class)

        # Forward only the weather readings the client actually supplied.
        weather_data = {}
        if weather_humidity is not None:
            weather_data['humidity'] = weather_humidity
        if weather_temperature is not None:
            weather_data['temperature'] = weather_temperature
        if weather_rainfall is not None:
            weather_data['rainfall'] = weather_rainfall

        if risk_calculator:
            risk_assessment = risk_calculator.calculate_risk(
                predicted_class, confidence_score, weather_data
            )
        else:
            risk_assessment = {"overall_risk": "unknown", "risk_factors": [], "recommendations": []}

        # The explanation is optional and only attempted with memory headroom.
        explanation_data = {}
        current_memory = get_memory_usage()
        if include_explanation and current_memory < 400 and explainer:
            try:
                # The upload stream was consumed by the read above; rewind it
                # so the explainer receives the image bytes, not b"".
                await file.seek(0)
                explanation_data = explainer.generate_explanation_lite(
                    await file.read(), predicted_class
                )
            except Exception as e:
                # Best effort: a failed explanation must not fail the request.
                print(f"Explanation generation failed: {e}")
                explanation_data = {"error": "Explanation unavailable due to memory constraints"}
        elif include_explanation:
            explanation_data = {"error": "Explanation disabled due to memory constraints"}

        optimize_memory()
        final_memory = get_memory_usage()

        result = {
            "predicted_class": predicted_class,
            "confidence": confidence_score,
            "class_probabilities": class_probs,
            "disease_info": disease_info,
            "risk_assessment": risk_assessment,
            "crop": extract_crop_name(predicted_class),
            "memory_usage": {
                "initial_mb": f"{initial_memory:.1f}",
                "final_mb": f"{final_memory:.1f}",
                "memory_optimized": final_memory < 512
            }
        }
        if explanation_data:
            result["explanation"] = explanation_data
        return JSONResponse(content=result)
    except HTTPException:
        # Deliberate HTTP errors (413, 400) keep their own status code
        # instead of being rewrapped as a 500 below.
        optimize_memory()
        raise
    except Exception as e:
        optimize_memory()
        print(f"Prediction error: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
def get_disease_info_lite(disease_class: str) -> Dict[str, Any]:
    """Return a trimmed knowledge-base entry for ``disease_class``.

    The entry is read from ``knowledge_base/disease_info.json`` two
    directories above this file.

    Args:
        disease_class: Key to look up in the knowledge base.

    Returns:
        A dict with ``symptoms``, ``solutions`` and ``prevention`` (at most
        three items each) and ``description`` (at most 200 characters).
        A class missing from the file yields empty lists and a placeholder
        description. A missing or unreadable file yields a generic fallback
        entry, so the function never returns None.
    """
    try:
        knowledge_base_path = Path(__file__).parent.parent / "knowledge_base" / "disease_info.json"
        if knowledge_base_path.exists():
            # Explicit encoding: do not depend on the platform default.
            with open(knowledge_base_path, 'r', encoding='utf-8') as f:
                all_disease_info = json.load(f)

            disease_info = all_disease_info.get(disease_class, {})
            return {
                "symptoms": disease_info.get("symptoms", [])[:3],
                "solutions": disease_info.get("solutions", [])[:3],
                "prevention": disease_info.get("prevention", [])[:3],
                "description": disease_info.get("description", "No description available")[:200]
            }
    except Exception as e:
        # Best effort: a broken knowledge base must not fail the prediction.
        print(f"Error loading disease info: {e}")

    # Reached when the file is missing or could not be read or parsed.
    return {
        "symptoms": ["Symptoms information unavailable"],
        "solutions": ["Please consult agricultural expert"],
        "prevention": ["Follow general plant care guidelines"],
        "description": "Disease information unavailable"
    }
def extract_crop_name(disease_class: str) -> str:
    """Return the crop a disease class label starts with, or "Unknown".

    A label matches a crop when it begins with the crop name either
    capitalised or fully lower-case.
    """
    for crop in ("Pepper", "Potato", "Tomato"):
        if disease_class.startswith((crop, crop.lower())):
            return crop
    return "Unknown"
if __name__ == "__main__":
    import uvicorn

    for banner_line in (
        "Starting memory-optimized disease detection API...",
        "Target: <512MB RAM usage",
    ):
        print(banner_line)

    # One worker and at most two concurrent requests.
    server_options = {
        "host": "0.0.0.0",
        "port": 8001,
        "workers": 1,
        "limit_concurrency": 2,
    }
    uvicorn.run(app, **server_options)