|
|
""" |
|
|
Plant Disease Classification API with robust out-of-distribution (OOD) detection.

Revision note: fixes earlier confidence and OOD-detection issues.
|
|
""" |
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Dict, Optional, Tuple |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import timm |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import albumentations as A |
|
|
from albumentations.pytorch import ToTensorV2 |
|
|
import logging |
|
|
from scipy.stats import norm |
|
|
import pickle |
|
|
|
|
|
|
|
|
# Configure the root logger once at import time (INFO and above), then create
# the module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """Static settings for model loading, preprocessing and OOD gating.

    All values are plain class attributes; the module-level ``config``
    instance below is what the rest of the file reads.
    """

    # Checkpoint file loaded by ``load_model``.
    MODEL_PATH = "best_model_final.pth"
    # Not read anywhere in the visible code.
    STATS_PATH = "class_statistics.pkl"
    # Side length (pixels) images are resized to before inference.
    IMG_SIZE = 224

    # Max-softmax probability below this marks the input as OOD.
    CONFIDENCE_THRESHOLD = 0.3
    # Second, lower confidence floor applied by ``predict_image``.
    OOD_THRESHOLD = 0.15
    # Softmax entropy above this marks the input as OOD.
    ENTROPY_THRESHOLD = 1.5
    # Energy score above this marks the input as OOD. Kept next to the other
    # thresholds; the value equals the 10.0 cut-off used by the detector.
    ENERGY_THRESHOLD = 10.0
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Feature flags; neither is consulted by the visible code.
    USE_MAHALANOBIS = False
    USE_ENSEMBLE = False

    # Index i of the model output corresponds to CLASS_NAMES[i]. Names follow
    # the "<plant>___<condition>" pattern that predict_image splits on.
    CLASS_NAMES = [
        'Apple___Apple_scab',
        'Apple___Black_rot',
        'Apple___Cedar_apple_rust',
        'Apple___healthy',
        'Blueberry___healthy',
        'Cherry_(including_sour)___Powdery_mildew',
        'Cherry_(including_sour)___healthy',
        'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
        'Corn_(maize)___Common_rust_',
        'Corn_(maize)___Northern_Leaf_Blight',
        'Corn_(maize)___healthy',
        'Grape___Black_rot',
        'Grape___Esca_(Black_Measles)',
        'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
        'Grape___healthy',
        'Orange___Haunglongbing_(Citrus_greening)',
        'Peach___Bacterial_spot',
        'Peach___healthy',
        'Pepper,_bell___Bacterial_spot',
        'Pepper,_bell___healthy',
        'Potato___Early_blight',
        'Potato___Late_blight',
        'Potato___healthy',
        'Raspberry___healthy',
        'Soybean___healthy',
        'Squash___Powdery_mildew',
        'Strawberry___Leaf_scorch',
        'Strawberry___healthy',
        'Tomato___Bacterial_spot',
        'Tomato___Early_blight',
        'Tomato___Late_blight',
        'Tomato___Leaf_Mold',
        'Tomato___Septoria_leaf_spot',
        'Tomato___Spider_mites Two-spotted_spider_mite',
        'Tomato___Target_Spot',
        'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
        'Tomato___Tomato_mosaic_virus',
        'Tomato___healthy',
    ]


config = Config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PlantDiseaseModel(nn.Module):
    """EfficientNet-B0 backbone followed by a custom classification head.

    The timm backbone's own classifier is replaced by ``nn.Identity`` so the
    backbone returns pooled features; ``self.classifier`` maps those features
    to ``num_classes`` logits.
    """

    def __init__(self, num_classes, dropout=0.4, pretrained=True):
        """Build the backbone and the classification head.

        Args:
            num_classes: Number of output logits.
            dropout: Dropout probability before the first linear layer; the
                second dropout layer uses half of this value.
            pretrained: Passed to ``timm.create_model``. Defaults to ``True``
                (the previous hard-coded behaviour). Pass ``False`` when a
                checkpoint will overwrite every weight anyway, which avoids
                fetching the pretrained weights.
        """
        super().__init__()

        self.backbone = timm.create_model('efficientnet_b0', pretrained=pretrained)
        num_features = self.backbone.classifier.in_features

        # Strip the stock classifier so backbone(x) yields feature vectors.
        self.backbone.classifier = nn.Identity()

        # Exposed so callers can size anything that consumes the features.
        self.feature_dim = num_features

        # Layer order (and therefore state_dict keys) must stay as-is to
        # remain compatible with saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, features)`` when ``return_features``."""
        features = self.backbone(x)
        logits = self.classifier(features)

        if return_features:
            return logits, features
        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OODDetector:
    """Out-of-distribution (OOD) heuristics for a single prediction.

    Three scores are computed from the logits: maximum softmax probability,
    softmax entropy and the energy score. All helpers call ``.item()`` or
    index row 0, so the detector expects a batch containing one sample.
    """

    def __init__(
        self,
        confidence_threshold: Optional[float] = None,
        entropy_threshold: Optional[float] = None,
        energy_threshold: Optional[float] = None,
    ):
        """Create a detector.

        Args:
            confidence_threshold: Inputs whose max softmax probability is
                below this are OOD. ``None`` reads
                ``config.CONFIDENCE_THRESHOLD`` at detection time.
            entropy_threshold: Inputs whose entropy is above this are OOD.
                ``None`` reads ``config.ENTROPY_THRESHOLD`` at detection time.
            energy_threshold: Inputs whose energy score is above this are
                OOD. ``None`` reads ``config.ENERGY_THRESHOLD`` if present,
                otherwise 10.0.
        """
        self.methods = ['confidence', 'entropy', 'energy']
        self.confidence_threshold = confidence_threshold
        self.entropy_threshold = entropy_threshold
        self.energy_threshold = energy_threshold

    @staticmethod
    def compute_entropy(probs: torch.Tensor) -> float:
        """Return the Shannon entropy (natural log) of ``probs``."""
        # The small epsilon keeps log() finite for zero probabilities.
        return -torch.sum(probs * torch.log(probs + 1e-10)).item()

    @staticmethod
    def compute_energy_score(logits: torch.Tensor, temperature: float = 1.0) -> float:
        """Return the energy score ``-T * logsumexp(logits / T)``."""
        return -temperature * torch.logsumexp(logits / temperature, dim=1).item()

    @staticmethod
    def compute_max_softmax(probs: torch.Tensor) -> float:
        """Return the largest probability in ``probs``."""
        return torch.max(probs).item()

    @staticmethod
    def _limit(explicit: Optional[float], config_attr: str,
               fallback: Optional[float] = None) -> float:
        """Return ``explicit`` if set, else the named attribute of ``config``."""
        if explicit is not None:
            return explicit
        if fallback is None:
            return getattr(config, config_attr)
        return getattr(config, config_attr, fallback)

    def detect_ood(self, logits: torch.Tensor, method: str = 'ensemble') -> Tuple[bool, Dict]:
        """Decide whether ``logits`` look out-of-distribution.

        Args:
            logits: Raw model outputs for one sample, indexed as
                ``(1, num_classes)``.
            method: ``'ensemble'`` (default) flags the input if any single
                criterion fires; ``'confidence'``, ``'entropy'`` or
                ``'energy'`` applies only that criterion.

        Returns:
            ``(is_ood, scores)`` where ``scores`` always holds all three
            values regardless of ``method``.

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        if method != 'ensemble' and method not in self.methods:
            raise ValueError(
                f"Unknown OOD method {method!r}; expected 'ensemble' or one of {self.methods}"
            )

        probs = F.softmax(logits, dim=1)

        scores = {
            'confidence': self.compute_max_softmax(probs),
            'entropy': self.compute_entropy(probs[0]),
            'energy': self.compute_energy_score(logits),
        }

        # Low confidence, high entropy or high energy each indicate OOD.
        flags = {
            'confidence': scores['confidence']
            < self._limit(self.confidence_threshold, 'CONFIDENCE_THRESHOLD'),
            'entropy': scores['entropy']
            > self._limit(self.entropy_threshold, 'ENTROPY_THRESHOLD'),
            'energy': scores['energy']
            > self._limit(self.energy_threshold, 'ENERGY_THRESHOLD', 10.0),
        }

        is_ood = any(flags.values()) if method == 'ensemble' else flags[method]
        return is_ood, scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transform(augment: bool = False):
    """Build the albumentations pipeline applied to an image before inference.

    The pipeline always resizes to ``config.IMG_SIZE`` square, normalizes
    with fixed per-channel mean/std (presumably the ImageNet statistics) and
    converts to a tensor. With ``augment=True`` a random horizontal flip and
    a random brightness/contrast jitter are inserted between the resize and
    the normalization.
    """
    resize = A.Resize(config.IMG_SIZE, config.IMG_SIZE)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if not augment:
        return A.Compose([resize, normalize, ToTensorV2()])

    return A.Compose([
        resize,
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        normalize,
        ToTensorV2(),
    ])
|
|
|
|
|
def preprocess_image(image_bytes: bytes, augment: bool = False) -> torch.Tensor:
    """Decode uploaded bytes and turn them into a model-ready batch of one.

    Args:
        image_bytes: Raw contents of the uploaded file.
        augment: Forwarded to ``get_transform``; enables the random
            augmentations.

    Returns:
        The transformed image with a leading batch dimension added.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
            Failures in the transform pipeline are NOT converted to 400;
            they propagate so callers report them as server-side errors.
    """
    # Only decoding can fail because of what the client sent, so only that
    # step is mapped to a 400 response.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    except Exception as e:
        logger.error(f"Error preprocessing image: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}") from e

    # Small images are accepted but logged; they are upscaled by the resize.
    if image.size[0] < 50 or image.size[1] < 50:
        logger.warning(f"Image too small: {image.size}")

    image_np = np.array(image)
    transform = get_transform(augment)
    augmented = transform(image=image_np)
    return augmented['image'].unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Build the network, load the checkpoint and create the OOD detector.

    Returns:
        ``(model, ood_detector)`` with the model moved to ``config.DEVICE``
        and switched to eval mode.

    If anything in the build/load sequence fails, the error is logged and a
    freshly constructed model is returned instead. That fallback model has a
    randomly initialised classification head, so its predictions are not
    meaningful; this best-effort behaviour is kept so the service can still
    start, but it is logged as a warning.
    """
    num_classes = len(config.CLASS_NAMES)
    checkpoint = None

    try:
        logger.info(f"Loading model from {config.MODEL_PATH}")
        model = PlantDiseaseModel(num_classes=num_classes, dropout=0.4)

        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects. Only load checkpoints from a trusted source.
        checkpoint = torch.load(config.MODEL_PATH, map_location=config.DEVICE, weights_only=False)

        # Accept both a full training checkpoint and a bare state_dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        model.load_state_dict(state_dict)
        model.to(config.DEVICE)
        model.eval()
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        logger.info("Trying fallback with pretrained backbone...")
        logger.warning("Fallback model has an untrained classifier head; predictions will not be meaningful")
        model = PlantDiseaseModel(num_classes=num_classes, dropout=0.4)
        model.to(config.DEVICE)
        model.eval()
        return model, OODDetector()

    logger.info(f"✅ Model loaded successfully on {config.DEVICE}")

    # Metadata logging is deliberately outside the try above: a badly typed
    # 'val_acc' must not discard a model that loaded correctly.
    if isinstance(checkpoint, dict) and 'epoch' in checkpoint and 'val_acc' in checkpoint:
        try:
            logger.info(f" Epoch: {checkpoint['epoch']}, Val Acc: {float(checkpoint['val_acc']):.2f}%")
        except (TypeError, ValueError):
            logger.info(f" Epoch: {checkpoint['epoch']}, Val Acc: {checkpoint['val_acc']}")

    return model, OODDetector()


# Loaded once at import time; shared by every request handler.
model, ood_detector = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PredictionResult(BaseModel):
    """Response model for successful prediction"""

    # "OK" for an accepted prediction.
    status: str
    # Full class name, e.g. "<plant>___<condition>".
    prediction: str
    # Max softmax probability of the predicted class.
    confidence: float
    plant: str
    disease: str
    is_healthy: bool
    # Each entry is {"class": <str>, "confidence": <float>}; the values are
    # therefore not all floats, which the previous Dict[str, float]
    # annotation wrongly required.
    top3_predictions: List[Dict[str, object]]
    recommendations: Optional[str] = None
    # Only populated when debug logging is enabled.
    ood_scores: Optional[Dict] = None
|
|
|
|
|
class OODResult(BaseModel):
    """Payload returned when an upload is judged out-of-distribution."""

    # "OOD" for rejected inputs.
    status: str
    # Short user-facing explanation.
    message: str
    # Max softmax probability of the best class.
    confidence: float
    # Entropy of the softmax distribution.
    entropy: float
    # Best-guess class name, when one is reported.
    top_guess: Optional[str] = None
    # Guidance for the user on what to upload instead.
    note: str
|
|
|
|
|
class HealthResponse(BaseModel):
    """Schema of the root ("/") status endpoint."""

    # Human-readable service status line.
    status: str
    # Whether the module-level model object exists.
    model_loaded: bool
    # Device string the model runs on ("cuda" or "cpu").
    device: str
    # Number of classes the model can predict.
    classes: int
    # Thresholds currently used for OOD gating.
    confidence_threshold: float
    ood_threshold: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def predict_image(image_tensor: torch.Tensor) -> Dict:
    """Classify one preprocessed image and gate the result with OOD checks.

    Args:
        image_tensor: A batch containing exactly one image (the ``.item()``
            calls below fail for larger batches).

    Returns:
        A JSON-serialisable dict. When the input is rejected it has
        ``status == "OOD"`` plus ``message``, ``confidence``, ``entropy``,
        ``top_guess`` and ``note``. Otherwise it has ``status == "OK"`` plus
        the prediction, its plant/disease split, the top-3 classes and
        treatment recommendations. ``ood_scores`` is added only when the
        logger's effective level is DEBUG or lower.
    """
    image_tensor = image_tensor.to(config.DEVICE)

    logits = model(image_tensor)
    probs = F.softmax(logits, dim=1)

    confidence, pred_idx = torch.max(probs, dim=1)
    confidence = confidence.item()
    pred_idx = pred_idx.item()
    predicted_class = config.CLASS_NAMES[pred_idx]

    is_ood, ood_scores = ood_detector.detect_ood(logits)

    # NOTE: an earlier "healthy class" special case lived here. It only set
    # is_ood to False when is_ood was already False, so it never changed the
    # outcome and has been removed.

    if is_ood or confidence < config.OOD_THRESHOLD:
        return {
            "status": "OOD",
            "message": "⚠️ Unable to identify plant disease",
            "confidence": round(confidence, 4),
            "entropy": round(ood_scores['entropy'], 4),
            "top_guess": predicted_class if confidence > 0.1 else "Unknown",
            "note": "This doesn't appear to be a clear plant leaf image. Please upload a focused image of a plant leaf against a neutral background.",
        }

    # Class names look like "<plant>___<condition>"; underscores inside each
    # part are turned into spaces for display.
    parts = predicted_class.split('___')
    plant = parts[0].replace('_', ' ').strip()
    disease = parts[1].replace('_', ' ').strip() if len(parts) > 1 else "Unknown"
    is_healthy = 'healthy' in disease.lower()

    recommendations = get_recommendations(plant, disease, is_healthy)

    # Top-k is only needed for accepted predictions, so compute it here.
    topk = min(3, len(config.CLASS_NAMES))
    topk_probs, topk_indices = torch.topk(probs, topk)
    top_predictions = [
        {"class": config.CLASS_NAMES[idx], "confidence": round(float(prob), 4)}
        for idx, prob in zip(topk_indices[0].tolist(), topk_probs[0].tolist())
    ]

    response = {
        "status": "OK",
        "prediction": predicted_class,
        "confidence": round(confidence, 4),
        "plant": plant,
        "disease": disease,
        "is_healthy": is_healthy,
        "top3_predictions": top_predictions,
        "recommendations": recommendations,
    }

    # Raw OOD scores are exposed only when debug logging is enabled.
    if logger.getEffectiveLevel() <= logging.DEBUG:
        response["ood_scores"] = {k: round(v, 4) for k, v in ood_scores.items()}

    return response
|
|
|
|
|
# Advice that differs per plant for the same disease name. The original table
# listed "Black rot" and "Bacterial spot" twice, so the first entry of each
# was silently overwritten. The plant each first entry belongs to is inferred
# from its position in that table (Apple group / Tomato group).
_PLANT_SPECIFIC_RECOMMENDATIONS = {
    ("apple", "black rot"): "Remove infected fruit and wood, apply fungicide during bloom, avoid overhead irrigation.",
    ("tomato", "bacterial spot"): "Use copper-based bactericides, avoid overhead watering, use pathogen-free seeds.",
}

# Advice keyed by disease name only. Order matters for the fuzzy pass below:
# the first matching key wins.
_GENERAL_RECOMMENDATIONS = {
    "Apple scab": "Apply fungicides in early spring, remove fallen leaves, prune for air circulation.",
    "Black rot": "Remove infected fruit, apply fungicide during bloom, ensure good air circulation.",
    "Cedar apple rust": "Remove nearby junipers, apply fungicide in spring, plant resistant varieties.",
    "Early blight": "Remove affected leaves, apply chlorothalonil or copper fungicide, rotate crops.",
    "Late blight": "REMOVE AND DESTROY infected plants immediately. Apply copper fungicide preventively.",
    "Bacterial spot": "Use copper sprays, avoid working with wet plants, sanitize tools.",
    "Leaf Mold": "Improve ventilation, reduce humidity, apply fungicide, remove affected leaves.",
    "Septoria leaf spot": "Remove infected leaves, apply chlorothalonil, avoid watering foliage.",
    "Common rust": "Plant resistant varieties, apply fungicide if detected early, rotate crops.",
    "Northern Leaf Blight": "Till infected debris, rotate crops, apply fungicide during silking.",
    "Powdery mildew": "Improve air circulation, apply sulfur or potassium bicarbonate, avoid excess nitrogen.",
    "Leaf scorch": "Ensure adequate watering, mulch to retain moisture, protect from hot winds.",
    "mosaic virus": "Remove infected plants, control aphids, use virus-free planting material.",
    "Yellow Leaf Curl Virus": "Control whiteflies, remove infected plants, use resistant varieties.",
}


def get_recommendations(plant: str, disease: str, is_healthy: bool) -> str:
    """Return a user-facing care or treatment message.

    Lookup order for a diseased plant:
      1. plant-specific advice for the exact (plant, disease) pair;
      2. general advice whose key equals the disease (case-insensitive);
      3. general advice whose key contains, or is contained in, the disease;
      4. a generic fallback message.

    Args:
        plant: Display name of the plant, e.g. "Tomato".
        disease: Display name of the disease, e.g. "Late blight".
        is_healthy: When true, the healthy message is returned and
            ``disease`` is ignored.
    """
    if is_healthy:
        return f"✅ Your {plant} plant appears healthy! Continue regular care and monitoring."

    generic = (
        f"⚠️ **{disease}** detected on {plant}. Remove affected leaves, improve air circulation, "
        "and consult local agricultural extension for specific treatment."
    )

    disease_key = disease.strip().lower()
    if not disease_key:
        # An empty string is a substring of every key; without this guard the
        # fuzzy pass would return advice for an unrelated disease.
        return generic

    advice = _PLANT_SPECIFIC_RECOMMENDATIONS.get((plant.strip().lower(), disease_key))

    if advice is None:
        advice = next(
            (rec for key, rec in _GENERAL_RECOMMENDATIONS.items() if key.lower() == disease_key),
            None,
        )

    if advice is None:
        advice = next(
            (
                rec
                for key, rec in _GENERAL_RECOMMENDATIONS.items()
                if key.lower() in disease_key or disease_key in key.lower()
            ),
            None,
        )

    if advice is None:
        return generic
    return f"⚠️ **{disease}** detected on {plant}. Recommendations: {advice}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; the metadata below appears in the generated docs.
app = FastAPI(
    title="Plant Disease Detection API",
    description="AI-powered plant disease classification with robust OOD detection",
    version="2.0.0",
)

# Fully permissive CORS policy.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests. Confirm whether credentials are needed; if
# not, disable them or list explicit origins.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_model=HealthResponse)
async def root():
    """Report basic service status and the active OOD thresholds."""
    status_payload = {
        "status": "✅ API is running with improved OOD detection",
        "model_loaded": model is not None,
        "device": config.DEVICE,
        "classes": len(config.CLASS_NAMES),
        "confidence_threshold": config.CONFIDENCE_THRESHOLD,
        "ood_threshold": config.OOD_THRESHOLD,
    }
    return status_payload
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Return a detailed status report, including the OOD configuration."""
    report = dict(
        status="healthy",
        model="EfficientNet-B0 with OOD detection",
        device=config.DEVICE,
        classes=len(config.CLASS_NAMES),
        ood_methods=ood_detector.methods,
        confidence_threshold=config.CONFIDENCE_THRESHOLD,
        entropy_threshold=config.ENTROPY_THRESHOLD,
        note="Confidence thresholds adjusted for 38-class problem",
    )
    return report
|
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict plant disease for one uploaded image, with OOD detection.

    The upload must declare an ``image/*`` content type and be at most 10MB.
    The response is either an "OK" prediction or an "OOD" rejection.

    Errors:
    - 400: missing or non-image content type, file too large, or
      undecodable image data
    - 500: any unexpected failure during preprocessing or inference
    """
    try:
        # content_type can be None when the client omits the header; treat
        # that as "not an image" instead of crashing with AttributeError.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image (JPEG, PNG, etc.)")

        # Measure the upload by seeking to the end, then rewind for reading.
        file.file.seek(0, 2)
        file_size = file.file.tell()
        file.file.seek(0)

        if file_size > 10 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="Image too large (max 10MB)")

        image_bytes = await file.read()
        image_tensor = preprocess_image(image_bytes)

        result = predict_image(image_tensor)

        if result["status"] == "OOD":
            logger.warning(f"OOD detected: {result['confidence']} confidence, {result['entropy']} entropy")
        else:
            logger.info(f"Prediction: {result['prediction']} ({result['confidence']:.2%})")

        return JSONResponse(content=result)

    except HTTPException:
        # Deliberate client-facing errors pass through with their status.
        raise
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
|
|
|
|
|
@app.post("/predict/batch")
async def predict_batch(files: List[UploadFile] = File(...)):
    """Predict up to five images in one request.

    Each file is processed independently. A failure on one file produces an
    ``{"status": "ERROR"}`` entry for that file and does not abort the batch.
    Files larger than 10MB are rejected per file, matching ``/predict``.
    """
    if len(files) > 5:
        raise HTTPException(status_code=400, detail="Maximum 5 images per batch")

    results = []
    for file in files:
        try:
            image_bytes = await file.read()
            if len(image_bytes) > 10 * 1024 * 1024:
                raise HTTPException(status_code=400, detail="Image too large (max 10MB)")

            image_tensor = preprocess_image(image_bytes)
            result = predict_image(image_tensor)
            result['filename'] = file.filename
            results.append(result)
        except HTTPException as e:
            # str() of an HTTPException differs between framework versions
            # and can be empty; .detail always carries the message.
            results.append({
                "filename": file.filename,
                "status": "ERROR",
                "message": str(e.detail)[:100],
            })
        except Exception as e:
            logger.error(f"Batch prediction error for {file.filename}: {e}", exc_info=True)
            results.append({
                "filename": file.filename,
                "status": "ERROR",
                "message": str(e)[:100],
            })

    return JSONResponse(content={"predictions": results})
|
|
|
|
|
@app.get("/debug/ood")
async def debug_ood():
    """Expose the thresholds currently used for OOD gating."""
    thresholds = dict(
        confidence_threshold=config.CONFIDENCE_THRESHOLD,
        ood_threshold=config.OOD_THRESHOLD,
        entropy_threshold=config.ENTROPY_THRESHOLD,
        note="For 38 classes, even correct predictions often have 30-60% confidence",
    )
    return thresholds
|
|
|
|
|
@app.get("/classes/stats")
async def class_statistics():
    """Summarise the class list: totals, healthy/disease split, plant names.

    ``plants`` is returned sorted so the response is identical across
    requests and processes (a plain ``set`` has no stable order).
    """
    names = config.CLASS_NAMES
    healthy_count = sum('healthy' in name for name in names)
    plants = sorted({name.split('___')[0] for name in names})

    return {
        "total": len(names),
        "healthy_classes": healthy_count,
        "disease_classes": len(names) - healthy_count,
        "plants": plants,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # uvicorn is imported here so it is only needed when this file is run as
    # a script.
    from uvicorn import run as serve

    logger.info("Starting server with improved OOD detection...")
    # Listen on all interfaces, port 7860.
    serve(app, host="0.0.0.0", port=7860)