Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import logging | |
| from typing import Dict, Optional, Any, List | |
| from pydantic import BaseModel, Field, field_validator | |
| from contextlib import asynccontextmanager | |
| from rapidfuzz import process, fuzz | |
| import urllib.parse | |
| import cv2 | |
| from sklearn.cluster import KMeans | |
| from collections import Counter | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from bleach import clean | |
| import numpy as np | |
| import tensorflow as tf | |
| from fastapi import FastAPI, File, Path, Query, UploadFile, HTTPException, status | |
| from PIL import Image | |
| import io | |
| from huggingface_hub import hf_hub_download | |
| from pydantic import BaseModel, Field | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration -- the Hugging Face settings below can be overridden through
# environment variables.
HF_MODEL_REPO: str = os.getenv("HF_MODEL_REPO", "yasyn14/smart-leaf-model")
HF_MODEL_FILENAME: str = os.getenv("HF_MODEL_FILENAME", "best_model_32epochs.keras")
# The HF_HOME value is passed as the download cache directory.
HF_CACHE_DIR: str = os.getenv("HF_HOME", "/home/appuser/huggingface")
IMAGE_SIZE: tuple = (300, 300)  # square target size for PIL resize / model input
MAX_FILE_SIZE_MB: int = 10  # upload limit, in megabytes
CONFIDENCE_THRESHOLD: float = 0.5  # below this, low-confidence photo tips are added

# Class labels are the model's output indices as strings: "0" .. "37".
CLASS_NAMES = [str(index) for index in range(38)]

# HTTP Status Messages (some contain an "{error}" placeholder filled via .format)
HTTP_MESSAGES = {
    "MODEL_NOT_LOADED": "Model not loaded. Please check server logs.",
    "INVALID_FILE_TYPE": "File must be an image",
    "FILE_TOO_LARGE": f"File size exceeds {MAX_FILE_SIZE_MB}MB limit",
    "PREDICTION_FAILED": "Prediction failed: {error}",
    "IMAGE_PROCESSING_FAILED": "Error preprocessing image: {error}",
    "MODEL_LOAD_SUCCESS": "Model loaded successfully",
    "MODEL_LOAD_FAILED": "Failed to load model",
    "LOW_CONFIDENCE": "Prediction confidence is low. Please try a clearer image.",
}

# Process-wide state populated by the lifespan handler at startup.
model: Optional[tf.keras.Model] = None
disease_guide: Dict[str, Dict[str, Any]] = {}
# Response models with improved validation
class DiseaseInfo(BaseModel):
    """Disease-guide entry as returned by the API.

    Every field has a default, so a partially populated guide entry still
    yields a valid object. The two ``mode="before"`` validators below run on
    the raw input: explicit ``None`` values are replaced by the field's
    declared default, and ``external_resources`` is reduced to a list of
    ``{"title": str, "url": str}`` dicts.
    """

    disease_name: Optional[str] = None
    common_names: List[str] = []
    crop: str = "Unknown"
    description: str = "No description available"
    symptoms: List[str] = []
    cause: Optional[str] = None
    treatment: List[str] = []
    image_urls: List[str] = []
    prevention: List[str] = []
    management_tips: str = ""
    risk_level: str = "Unknown"
    sprayer_intervals: str = ""
    localized_tips: str = ""
    type: str = "Unknown"
    external_resources: List[Dict[str, str]] = []
    is_healthy: bool = False

    @field_validator("external_resources", mode="before")
    @classmethod
    def validate_external_resources(cls, v):
        """Keep only dict items, each reduced to string ``title`` and ``url`` keys.

        Any non-list value (including ``None``) becomes an empty list.
        """
        if not isinstance(v, list):
            return []
        return [
            {
                "title": str(item.get("title", "")),
                "url": str(item.get("url", "")),
            }
            for item in v
            if isinstance(item, dict)
        ]

    @field_validator("*", mode="before")
    @classmethod
    def validate_all_fields(cls, v, info):
        """Replace an explicit ``None`` with the field's declared default.

        ``ValidationInfo`` has no ``default`` attribute, so the default is read
        from ``cls.model_fields``. This keeps e.g. ``description`` at
        "No description available" instead of "Unknown".
        """
        if v is not None:
            return v
        field = cls.model_fields.get(info.field_name)
        if field is None:
            return v
        # get_default() returns a fresh copy for mutable defaults such as [].
        return field.get_default(call_default_factory=True)
class PredictionItem(BaseModel):
    # One ranked entry of PredictionResponse.all_predictions.
    confidence: float  # model score for this class (predict_image rounds to 4 places)
    label: str  # human-readable name produced by clean_class_name
    confidence_level: str  # bucket from get_confidence_level: High / Medium / Low
class PredictionResponse(BaseModel):
    # Payload returned by predict_image for a single uploaded image.
    success: bool
    predicted_class: str
    predicted_class_index: int
    clean_class_name: str = Field(description="Human-readable class name")
    confidence: float
    confidence_level: str = Field(description="High/Medium/Low confidence level")
    all_predictions: list[PredictionItem] = Field(description="Top 5 predictions with confidence scores")
    disease_info: DiseaseInfo
    recommendations: list[str] = Field(description="Action recommendations based on prediction")
    message: str
    class_id: str = Field(description="URL-safe class identifier")
class HealthResponse(BaseModel):
    # Status payload shared by the root and health-check handlers.
    status: str  # "running", "healthy" or "unhealthy" as set by the handlers
    model_loaded: bool
    total_classes: int  # len(CLASS_NAMES)
    available_diseases: int  # guide entries that have a disease_name
    healthy_classes: int  # total_classes minus available_diseases
    message: str
class SearchResult(BaseModel):
    # One disease-guide entry as returned by the listing/search/lookup handlers.
    class_name: str
    class_id: str = Field(description="URL-safe class identifier")
    disease_info: DiseaseInfo
    relevance_score: Optional[float] = None  # only set for fuzzy suggestions
class SearchResponse(BaseModel):
    # Result of search_diseases: exact matches, or fuzzy suggestions when there are none.
    results: list[SearchResult]
    suggestions: list[SearchResult] = []
    total_results: int
    message: str = ""
class LeafValidationResponse(BaseModel):
    # Outcome of validate_leaf_image's heuristic checks.
    is_leaf: bool
    confidence: float  # fraction of checks that passed (0.0 - 1.0)
    reason: str  # human-readable summary of each check
    validation_method: str  # "disabled" or "heuristic_multi_check"
# Thresholds for the leaf-validation heuristics in validate_leaf_image.
LEAF_VALIDATION_ENABLED = True  # False makes validate_leaf_image accept every image
MIN_GREEN_PERCENTAGE = 15  # minimum share of green pixels, in percent
MIN_EDGE_DENSITY = 0.1  # minimum fraction of pixels marked as edges by Canny
MAX_UNIFORM_COLOR_PERCENTAGE = 80  # maximum share of the largest colour cluster, in percent (rejects solid backgrounds)
def detect_green_content(image_array: np.ndarray) -> tuple[float, str]:
    """Measure the percentage of green pixels in an RGB image.

    Args:
        image_array: RGB image of shape (H, W, 3). Floats with max <= 1 are
            treated as normalized [0, 1] (as produced by preprocess_image).
            Any other non-uint8 data is treated as 0-255, then clipped and
            cast to uint8.

    Returns:
        (green_percentage, reason). The percentage covers pixels inside the
        HSV green band, and reason is a short summary string. Any failure is
        logged and yields (0.0, "Green detection failed").
    """
    try:
        rgb = np.asarray(image_array)
        # The HSV thresholds below are on OpenCV's 8-bit scale (H 0-179,
        # S/V 0-255), so always hand cvtColor uint8 data. Float input above 1
        # previously went in as float and broke the thresholds or the call.
        if rgb.dtype != np.uint8:
            if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
                rgb = rgb * 255.0
            rgb = np.clip(rgb, 0, 255).astype(np.uint8)

        hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

        # Hue band 35-85 covers yellow-green through cyan-green. Saturation
        # and value must be >= 40, which skips greyish and very dark pixels.
        lower_green = np.array([35, 40, 40])
        upper_green = np.array([85, 255, 255])
        green_mask = cv2.inRange(hsv, lower_green, upper_green)

        green_percentage = float(np.count_nonzero(green_mask)) / green_mask.size * 100
        reason = f"Green content: {green_percentage:.1f}%"
        return green_percentage, reason
    except Exception as e:
        logger.warning(f"Green detection failed: {e}")
        return 0.0, "Green detection failed"
def detect_edge_density(image_array: np.ndarray) -> tuple[float, str]:
    """Measure the fraction of edge pixels, a proxy for leaf veins and texture.

    Args:
        image_array: RGB image of shape (H, W, 3) or grayscale (H, W). Floats
            with max <= 1 are treated as normalized [0, 1]. Any other
            non-uint8 data is treated as 0-255, then clipped and cast to uint8.

    Returns:
        (edge_density, reason). edge_density is the fraction (0-1) of pixels
        that Canny (thresholds 50/150) marks as edges. Any failure is logged
        and yields (0.0, "Edge detection failed").
    """
    try:
        arr = np.asarray(image_array)
        # Canny requires 8-bit input. Normalise before the colour/gray split;
        # previously 2-D float input skipped this and made Canny fail.
        if arr.dtype != np.uint8:
            if np.issubdtype(arr.dtype, np.floating) and arr.max() <= 1.0:
                arr = arr * 255.0
            arr = np.clip(arr, 0, 255).astype(np.uint8)

        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY) if arr.ndim == 3 else arr

        edges = cv2.Canny(gray, 50, 150)
        edge_density = float(np.count_nonzero(edges)) / edges.size
        reason = f"Edge density: {edge_density:.3f}"
        return edge_density, reason
    except Exception as e:
        logger.warning(f"Edge detection failed: {e}")
        return 0.0, "Edge detection failed"
def detect_color_diversity(image_array: np.ndarray) -> tuple[float, str]:
    """Return the share of pixels in the largest of 5 KMeans colour clusters.

    A high share means the image is dominated by one colour (e.g. a plain
    background). Input with max <= 1 is scaled to 0-255 uint8 first.

    Returns:
        (dominant_color_percentage, reason). On any failure the warning is
        logged and (100.0, "Color diversity detection failed") is returned.
    """
    try:
        source = image_array
        if source.max() <= 1.0:
            source = (source * 255).astype(np.uint8)

        # One row per pixel, three colour channels per row.
        flat_pixels = source.reshape(-1, 3)
        cluster_labels = KMeans(n_clusters=5, random_state=42, n_init=10).fit(flat_pixels).labels_

        largest_cluster_size = Counter(cluster_labels).most_common(1)[0][1]
        share = (largest_cluster_size / len(flat_pixels)) * 100
        return share, f"Dominant color: {share:.1f}%"
    except Exception as err:
        logger.warning(f"Color diversity detection failed: {err}")
        return 100.0, "Color diversity detection failed"
def validate_leaf_image(image_array: np.ndarray) -> LeafValidationResponse:
    """Judge whether an image plausibly shows a leaf using three heuristics.

    The checks are green content, edge density and colour diversity. The
    image counts as a leaf when at least 2 of the 3 pass, and ``confidence``
    is the fraction that passed. A leading batch axis (4-D input) is
    dropped first. When LEAF_VALIDATION_ENABLED is False every image is
    accepted.
    """
    if not LEAF_VALIDATION_ENABLED:
        return LeafValidationResponse(
            is_leaf=True,
            confidence=1.0,
            reason="Validation disabled",
            validation_method="disabled"
        )

    frame = image_array[0] if len(image_array.shape) == 4 else image_array

    green_pct, green_msg = detect_green_content(frame)
    edge_value, edge_msg = detect_edge_density(frame)
    dominant_pct, color_msg = detect_color_diversity(frame)

    outcomes = (
        green_pct >= MIN_GREEN_PERCENTAGE,
        edge_value >= MIN_EDGE_DENSITY,
        dominant_pct <= MAX_UNIFORM_COLOR_PERCENTAGE,
    )
    messages = (green_msg, edge_msg, color_msg)

    passed = sum(outcomes)
    checks = len(outcomes)
    summary = f"Validation checks: {passed}/{checks} passed. " + "; ".join(messages)

    return LeafValidationResponse(
        is_leaf=passed >= 2,
        confidence=passed / checks,
        reason=summary,
        validation_method="heuristic_multi_check"
    )
def create_class_id(class_name: str) -> str:
    """Percent-encode ``class_name`` for use as a single URL path segment.

    Nothing is treated as safe, so even ``/`` is escaped.
    """
    encoded = urllib.parse.quote(string=class_name, safe="")
    return encoded
def decode_class_id(class_id: str) -> str:
    """Decode an identifier produced by create_class_id back to the class name.

    Decoding is strict UTF-8. ``unquote``'s default ``errors='replace'``
    silently produced U+FFFD characters, which made the caller's
    ``UnicodeDecodeError`` handler unreachable.

    Raises:
        UnicodeDecodeError: if the percent-escapes do not form valid UTF-8.
    """
    return urllib.parse.unquote(class_id, encoding="utf-8", errors="strict")
def load_disease_guide(guide_path: str = "disease_guide.json") -> Dict[str, Dict[str, Any]]:
    """Load the disease guide (class index -> guide entry) from a JSON file.

    This is best-effort and never raises, so a bad guide cannot abort startup.

    Args:
        guide_path: JSON file to read. A relative path resolves against the
            current working directory. Defaults to ``disease_guide.json``.

    Returns:
        The parsed mapping with non-dict entries dropped, or ``{}`` when the
        file is missing, unreadable, invalid JSON, or not a JSON object.
    """
    try:
        # Open directly instead of checking os.path.exists first (no race).
        with open(guide_path, 'r', encoding='utf-8') as f:
            guide = json.load(f)
    except FileNotFoundError:
        logger.warning(f"Disease guide file not found at {guide_path}")
        return {}
    except Exception as e:
        logger.error(f"Failed to load disease guide: {str(e)}")
        return {}

    # Callers iterate .items() and call .get() on each entry, so anything
    # other than an object of objects would fail later at request time.
    if not isinstance(guide, dict):
        logger.error(
            f"Failed to load disease guide: expected a JSON object, got {type(guide).__name__}"
        )
        return {}
    bad_keys = [key for key, entry in guide.items() if not isinstance(entry, dict)]
    if bad_keys:
        logger.warning(f"Ignoring non-object disease guide entries: {bad_keys}")
        guide = {key: entry for key, entry in guide.items() if isinstance(entry, dict)}

    logger.info(f"Loaded disease guide with {len(guide)} entries")
    return guide
def clean_class_name(class_index: str, disease_info: Optional[Dict[str, Any]] = None) -> str:
    """Return a display name for a class.

    If the guide entry has a truthy ``disease_name``, the result is
    "<crop> - <disease_name>" (crop defaults to "Unknown"). Otherwise it is
    "Class <index> (Healthy/Unknown)".
    """
    name = (disease_info or {}).get('disease_name')
    if not name:
        return f"Class {class_index} (Healthy/Unknown)"
    crop = disease_info.get('crop', 'Unknown')
    return f"{crop} - {name}"
def get_confidence_level(confidence: float) -> str:
    """Bucket a score: "High" if >= 0.8, "Medium" if >= 0.6, else "Low"."""
    for floor, label in ((0.8, "High"), (0.6, "Medium")):
        if confidence >= floor:
            return label
    return "Low"
def sanitize_search_query(query: str) -> str:
    """Turn raw user input into plain search text.

    HTML tags are stripped with bleach. The HTML entities bleach introduces
    (e.g. ``&`` -> ``&amp;``) are decoded again, because the result is only
    used for substring/fuzzy matching, where entities would never match.
    Surrounding whitespace is trimmed and the result is capped at 100
    characters.

    NOTE: the return value is plain text, not HTML-safe; escape it if it is
    ever rendered.
    """
    import html

    text = html.unescape(clean(query.strip(), tags=[], strip=True))
    return text.strip()[:100]  # Limit length
def safe_create_disease_info(class_index: str, disease_data: Optional[Dict[str, Any]] = None) -> DiseaseInfo:
    """Build a DiseaseInfo from a raw guide entry without ever raising.

    Args:
        class_index: Class index string. Used only in descriptions and logs.
        disease_data: Guide entry for the class, or None/empty for
            healthy/unknown classes.

    Returns:
        A DiseaseInfo. A falsy ``disease_data`` yields a "healthy/unknown"
        object (``is_healthy=True``). Otherwise the entry's values are used,
        and missing *or null* values fall back to defaults. If construction
        still fails, the error is logged and a generic "Unknown" object is
        returned.
    """
    def text_or_empty(value: Any) -> str:
        # None becomes "" rather than the literal string "None".
        return "" if value is None else str(value)

    try:
        if not disease_data:
            # Set up base defaults to always match DiseaseInfo model
            base_defaults = {
                'disease_name': None,
                'common_names': [],
                'crop': "Unknown",
                'description': f"This appears to be a healthy plant or an unrecognized condition for class {class_index}",
                'symptoms': [],
                'cause': None,
                'treatment': [],
                'image_urls': [],
                'prevention': [],
                'management_tips': "",
                'risk_level': "Unknown",
                'sprayer_intervals': "",
                'localized_tips': "",
                'type': "Healthy/Unknown",
                'external_resources': [],
                'is_healthy': True
            }
            return DiseaseInfo(**base_defaults)

        # Defaults are applied when a key is missing OR explicitly null. A JSON
        # null would otherwise fail str/list validation and throw away the
        # whole entry in favour of the generic fallback below.
        string_defaults = {
            'crop': 'Unknown',
            'description': 'No description available',
            'management_tips': '',
            'risk_level': 'Unknown',
            'sprayer_intervals': '',
            'localized_tips': '',
            'type': 'Unknown',
        }
        list_fields = ('common_names', 'symptoms', 'treatment', 'image_urls', 'prevention')

        final_data: Dict[str, Any] = {
            'disease_name': disease_data.get('disease_name'),
            'cause': disease_data.get('cause'),
            'is_healthy': False,
        }
        for key, default in string_defaults.items():
            value = disease_data.get(key)
            final_data[key] = default if value is None else value
        for key in list_fields:
            value = disease_data.get(key)
            if value is None:
                value = []
            elif isinstance(value, str):
                # A lone string where a list is expected becomes a one-item list.
                value = [value]
            final_data[key] = value

        # Validate and normalize external_resources to string title/url pairs.
        resources = disease_data.get('external_resources')
        if isinstance(resources, list):
            final_data['external_resources'] = [
                {
                    'title': text_or_empty(res.get('title')),
                    'url': text_or_empty(res.get('url')),
                }
                for res in resources
                if isinstance(res, dict)
            ]
        else:
            final_data['external_resources'] = []

        return DiseaseInfo(**final_data)
    except Exception as e:
        logger.error(f"Error creating DiseaseInfo for class {class_index}: {str(e)}")
        logger.error(f"Data causing error: {disease_data}")
        # Return a safe fallback object with all required fields
        return DiseaseInfo(
            disease_name="Unknown",
            common_names=[],
            crop="Unknown",
            description=f"Error loading disease information for class {class_index}",
            symptoms=[],
            cause="Unknown",
            treatment=[],
            image_urls=[],
            prevention=[],
            management_tips="",
            risk_level="Unknown",
            sprayer_intervals="",
            localized_tips="",
            type="Unknown",
            external_resources=[],
            is_healthy=False
        )
def get_recommendations(class_index: str, confidence: float, disease_info: DiseaseInfo) -> List[str]:
    """Build the ordered list of user-facing recommendation lines.

    Args:
        class_index: Predicted class index. Not used; kept for interface
            compatibility.
        confidence: Top-1 score. Below CONFIDENCE_THRESHOLD, photo-quality
            tips are added first.
        disease_info: Guide information for the predicted class.

    Returns:
        Recommendation strings. Healthy/unknown classes get generic care tips.
        Diseases get a risk banner (at the very top, for High/Medium risk,
        case-insensitive), followed by the treatment, prevention, management,
        spraying and localized tips from the guide and generic containment
        advice. External resources are appended last when present.
    """
    recommendations: List[str] = []

    # Add confidence-based recommendations first
    if confidence < CONFIDENCE_THRESHOLD:
        recommendations.extend([
            "⚠️ Low confidence prediction - consider taking a clearer, well-lit photo",
            "📸 Ensure the leaf/plant fills most of the frame and is in focus",
            "💡 Try taking photos in natural light for better results"
        ])

    if disease_info.is_healthy or not disease_info.disease_name:
        # Healthy plant recommendations
        recommendations.extend([
            "✅ Plant appears healthy - continue current care routine",
            "👀 Monitor regularly for any changes in leaf color, spots, or wilting",
            "💧 Maintain proper watering schedule - avoid overwatering",
            "🌱 Ensure adequate fertilization and soil drainage",
            "🛡️ Consider preventive measures during disease-prone seasons",
            "🌿 Keep the growing area clean and remove fallen debris"
        ])
    else:
        # Risk banner goes ahead of everything, including low-confidence tips.
        # Compared case-insensitively, like the risk_level filter of the
        # disease listing, so guide values such as "high" still trigger it.
        risk = (disease_info.risk_level or "").strip().lower()
        if risk == "high":
            recommendations.insert(0, "🚨 HIGH RISK DISEASE: Take immediate action to prevent crop loss")
        elif risk == "medium":
            recommendations.insert(0, "⚠️ MEDIUM RISK DISEASE: Prompt treatment recommended")

        # Add disease identification
        recommendations.append(f"🔬 Disease identified: {disease_info.disease_name}")

        # Add treatments from JSON
        if disease_info.treatment:
            recommendations.append("💊 **TREATMENT RECOMMENDATIONS:**")
            for i, treatment in enumerate(disease_info.treatment, 1):
                recommendations.append(f" {i}. {treatment}")
        else:
            recommendations.append("💊 Consult agricultural expert for proper treatment")

        # Add prevention measures from JSON
        if disease_info.prevention:
            recommendations.append("🛡️ **PREVENTION MEASURES:**")
            for i, prevention in enumerate(disease_info.prevention, 1):
                recommendations.append(f" {i}. {prevention}")

        if disease_info.management_tips:
            recommendations.append(f"💡 **MANAGEMENT TIP:** {disease_info.management_tips}")
        if disease_info.sprayer_intervals:
            recommendations.append(f"🚿 **SPRAYING SCHEDULE:** {disease_info.sprayer_intervals}")
        if disease_info.localized_tips:
            recommendations.append(f"🎯 **LOCALIZED TIP:** {disease_info.localized_tips}")

        # General disease management recommendations
        recommendations.extend([
            "🔒 Isolate affected plants to prevent spread to healthy plants",
            "👀 Monitor other plants regularly for similar symptoms",
            "🗑️ Remove and destroy infected plant material properly",
            "🧼 Sanitize tools and hands after handling infected plants"
        ])

    # Add external resources if available
    if disease_info.external_resources:
        recommendations.append("📚 **EXTERNAL RESOURCES:**")
        for resource in disease_info.external_resources:
            url = resource.get("url") or ""
            # An empty title (the validated default) would render as "[]".
            title = resource.get("title") or ""
            if not title and not url:
                continue  # nothing worth showing
            title = title or "Resource"
            if url:
                recommendations.append(f" 🔗 [{title}]({url})")
            else:
                recommendations.append(f" 🔖 {title}")

    return recommendations
def download_model_from_hf() -> str:
    """Fetch HF_MODEL_FILENAME from the HF_MODEL_REPO Hub repository.

    Returns:
        Local filesystem path of the downloaded file, inside HF_CACHE_DIR.

    Raises:
        Whatever ``hf_hub_download`` raised, after logging it.
    """
    try:
        logger.info(f"Downloading model from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
        local_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            cache_dir=HF_CACHE_DIR,
        )
        logger.info(f"Model downloaded to: {local_path}")
    except Exception as exc:
        logger.error(f"Failed to download model: {str(exc)}")
        raise
    return local_path
def load_model() -> tf.keras.Model:
    """Download the model from Hugging Face and load it for inference.

    The model is loaded with ``compile=False``. ``predict`` needs no
    optimizer or loss, so the previous recompile with an arbitrary
    adam/sparse_categorical_crossentropy setup did not speed up inference.
    Skipping compilation also avoids deserializing the saved training
    configuration.

    Returns:
        The loaded Keras model.

    Raises:
        Whatever the download or ``load_model`` raised, after logging it.
    """
    try:
        model_path = download_model_from_hf()
        loaded_model = tf.keras.models.load_model(model_path, compile=False)
        logger.info("Model loaded successfully")
        return loaded_model
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
def validate_file_size(file_size: int) -> None:
    """Reject payloads over MAX_FILE_SIZE_MB.

    Args:
        file_size: Payload size in bytes.

    Raises:
        HTTPException: 413 with the FILE_TOO_LARGE message when over the limit.
    """
    limit_bytes = MAX_FILE_SIZE_MB * 1024 ** 2
    if file_size <= limit_bytes:
        return
    raise HTTPException(
        status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
        detail=HTTP_MESSAGES["FILE_TOO_LARGE"],
    )
def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Decode image bytes into a model-ready batch.

    Args:
        image_bytes: Encoded image file contents.

    Returns:
        float32 array of shape (1, *IMAGE_SIZE, 3) with values in [0, 1].

    Raises:
        HTTPException: 413 if the payload exceeds MAX_FILE_SIZE_MB (raised
            by validate_file_size and no longer re-wrapped as a 400), or 400
            if the bytes cannot be decoded or are in an unsupported format.
    """
    # Size check happens outside the try so its 413 propagates unchanged.
    validate_file_size(len(image_bytes))

    # "MPO" is how PIL identifies many multi-picture phone-camera JPEGs.
    # "GIF" matches the .gif extension accepted by is_image_file (the first
    # frame is used).
    allowed_formats = {'JPEG', 'MPO', 'PNG', 'BMP', 'TIFF', 'WEBP', 'GIF'}

    try:
        with Image.open(io.BytesIO(image_bytes)) as image:
            if image.format not in allowed_formats:
                raise ValueError(f"Unsupported image format: {image.format}")
            rgb = image if image.mode == 'RGB' else image.convert('RGB')
            # resize() returns a new, fully loaded image that outlives the with-block.
            resized = rgb.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)

        img_array = np.asarray(resized, dtype=np.float32) / 255.0
        # Add batch dimension
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=HTTP_MESSAGES["IMAGE_PROCESSING_FAILED"].format(error=str(e))
        ) from e
def predict_image(image_bytes: bytes) -> PredictionResponse:
    """Classify an uploaded image and assemble the full API response.

    Args:
        image_bytes: Encoded image file contents.

    Returns:
        PredictionResponse with the top-1 class, the top-5 ranking, guide
        information and recommendations.

    Raises:
        HTTPException: 503 if the model is not loaded. The 400/413 raised by
            preprocess_image propagates unchanged (previously re-wrapped as
            500). Any other failure becomes a 500.
    """
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=HTTP_MESSAGES["MODEL_NOT_LOADED"]
        )

    try:
        processed_image = preprocess_image(image_bytes)

        predictions = model.predict(processed_image, verbose=0)
        scores = predictions[0]
        # np.argmax returns a NumPy integer; store a plain int in the response.
        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])
        predicted_class = str(predicted_class_idx)

        # Fetch disease info
        disease_data = disease_guide.get(predicted_class)
        disease_info = safe_create_disease_info(predicted_class, disease_data)

        # Format metadata
        clean_name = clean_class_name(predicted_class, disease_data)
        confidence_level = get_confidence_level(confidence)
        class_id = create_class_id(predicted_class)

        # Top 5 predictions, highest score first
        top_indices = np.argsort(scores)[-5:][::-1]
        all_predictions = []
        for idx in top_indices:
            class_str = str(idx)
            class_confidence = float(scores[idx])
            class_info = disease_guide.get(class_str, None)
            all_predictions.append({
                "confidence": round(class_confidence, 4),
                "label": clean_class_name(class_str, class_info),
                "confidence_level": get_confidence_level(class_confidence)
            })

        recommendations = get_recommendations(predicted_class, confidence, disease_info)

        return PredictionResponse(
            success=True,
            predicted_class=clean_name,
            predicted_class_index=predicted_class_idx,
            clean_class_name=clean_name,
            message="Prediction successful",
            all_predictions=all_predictions,
            class_id=class_id,
            confidence=round(confidence, 4),
            confidence_level=confidence_level,
            disease_info=disease_info,
            recommendations=recommendations
        )
    except HTTPException:
        # Already carries the right status (e.g. 400/413 from preprocessing).
        raise
    except Exception as e:
        logger.error(f"Prediction failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=HTTP_MESSAGES["PREDICTION_FAILED"].format(error=str(e))
        ) from e
def is_image_file(filename: str) -> bool:
    """Return True when ``filename`` ends in a known image extension.

    The comparison is case-insensitive. Empty or None names return False.
    """
    if not filename:
        return False
    return filename.lower().endswith(
        ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the disease guide and model at startup; log at shutdown.

    Wrapped in ``asynccontextmanager`` (imported for this purpose but never
    applied in the original), which is the form FastAPI's ``lifespan=``
    expects. Startup failures are logged and leave ``model`` as None, so the
    API still starts and reports itself unhealthy.
    """
    # Startup
    global model, disease_guide
    try:
        logger.info("Starting up... Loading disease guide and model")
        disease_guide = load_disease_guide()
        model = load_model()

        # Pre-warm the model with a dummy prediction so the first request
        # does not pay the graph-building cost.
        dummy_image = np.random.rand(1, *IMAGE_SIZE, 3).astype(np.float32)
        _ = model.predict(dummy_image, verbose=0)
        logger.info("Model pre-warmed successfully")
    except Exception as e:
        logger.error(f"Failed to initialize during startup: {str(e)}")
        model = None

    yield

    # Shutdown
    logger.info("Shutting down...")
# Create FastAPI app
app = FastAPI(
    title="Plant Disease Prediction API",
    description="API for predicting plant diseases from leaf images using deep learning",
    version="2.2.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS: allowed origins come from CORS_ALLOW_ORIGINS (comma-separated). When it
# is unset or blank, every origin ("*") is allowed, as before.
# NOTE(review): "*" together with allow_credentials=True lets any site make
# credentialed requests (Starlette's CORSMiddleware reflects the caller's
# Origin in that case). Set explicit origins in production.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def root():
    """Root endpoint with API information"""
    # NOTE(review): no route decorator appears above this handler in this copy; confirm registration.
    total = len(CLASS_NAMES)
    diseases = sum(1 for entry in disease_guide.values() if entry.get("disease_name"))
    return HealthResponse(
        status="running",
        model_loaded=model is not None,
        total_classes=total,
        available_diseases=diseases,
        healthy_classes=total - diseases,
        message="Plant Disease Prediction API is running"
    )
async def health_check():
    """Health check endpoint"""
    # NOTE(review): no route decorator appears above this handler in this copy; confirm registration.
    total = len(CLASS_NAMES)
    diseases = sum(1 for entry in disease_guide.values() if entry.get("disease_name"))
    loaded = model is not None
    state = "healthy" if loaded else "unhealthy"
    message_key = "MODEL_LOAD_SUCCESS" if loaded else "MODEL_NOT_LOADED"
    return HealthResponse(
        status=state,
        model_loaded=loaded,
        total_classes=total,
        available_diseases=diseases,
        healthy_classes=total - diseases,
        message=HTTP_MESSAGES[message_key]
    )
async def predict_plant_disease(file: UploadFile = File(...)):
    """
    Predict plant disease from uploaded image
    - **file**: Single image file to analyze (max 10MB)
    Returns comprehensive prediction with confidence score, disease information, and recommendations
    """
    # NOTE(review): no route decorator appears above this handler in this copy; confirm registration.
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No filename provided"
        )
    if not is_image_file(file.filename):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{HTTP_MESSAGES['INVALID_FILE_TYPE']}: {file.filename}"
        )

    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    try:
        # Read at most one byte past the limit. That is enough to detect an
        # oversized upload without pulling all of it into memory.
        image_bytes = await file.read(max_bytes + 1)
        if len(image_bytes) > max_bytes:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"Uploaded image exceeds the maximum allowed size of {MAX_FILE_SIZE_MB}MB"
            )
        # The old `finally: if image_bytes` raised UnboundLocalError when
        # read() failed. It was dropped because the local goes out of scope
        # on return anyway.
        return predict_image(image_bytes)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing file {file.filename}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=HTTP_MESSAGES["IMAGE_PROCESSING_FAILED"].format(error=str(e))
        ) from e
async def get_all_plant_diseases(
    crop: Optional[str] = Query(None, description="Filter by crop name (e.g. Apple, Tomato)"),
    disease_type: Optional[str] = Query(None, description="Filter by disease type (Fungal, Bacterial, Viral)"),
    risk_level: Optional[str] = Query(None, description="Filter by risk level (High, Medium, Low)"),
    include_healthy: bool = Query(False, description="Include healthy/unknown classes")
):
    """
    Get all plant diseases with optional filtering
    """
    # NOTE(review): no route decorator appears above this handler in this copy; confirm registration.
    # Filters are case-insensitive exact matches and apply only to entries with
    # a disease_name. Healthy/unknown entries are returned unfiltered when
    # include_healthy is true.

    def matches(entry: Dict[str, Any], key: str, wanted: Optional[str]) -> bool:
        # A null or missing guide value compares as "" instead of raising on .lower().
        return not wanted or str(entry.get(key) or "").lower() == wanted.lower()

    diseases = []
    for class_name, info in disease_guide.items():
        if not isinstance(info, dict):
            continue
        is_disease = bool(info.get("disease_name"))
        if not is_disease and not include_healthy:
            continue
        if is_disease and not (
            matches(info, "crop", crop)
            and matches(info, "type", disease_type)
            and matches(info, "risk_level", risk_level)
        ):
            continue
        diseases.append(SearchResult(
            class_name=class_name,
            class_id=create_class_id(class_name),
            disease_info=safe_create_disease_info(class_name, info if is_disease else None)
        ))
    return diseases
async def search_diseases(
    query: str = Query(..., min_length=1, description="Search term"),
    limit: int = Query(10, ge=1, le=50, description="Maximum number of results"),
    include_healthy: bool = Query(False, description="Include healthy/unknown classes in search")
):
    """
    Search plant diseases with fuzzy matching and relevance scoring
    """
    # NOTE(review): no route decorator appears above this handler in this copy; confirm registration.
    # Exact (substring) matches are returned first, up to `limit`. Only when
    # there are none are fuzzy suggestions scoring above 60 returned.
    cleaned_query = sanitize_search_query(query)
    if not cleaned_query:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Search query cannot be empty"
        )
    if len(cleaned_query) < 2:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Search query must be at least 2 characters long"
        )

    def as_text(value: Any) -> str:
        # Flatten a guide value (None, str, or list) into searchable text.
        # Null values previously broke " ".join().
        if value is None:
            return ""
        if isinstance(value, (list, tuple)):
            return " ".join(str(item) for item in value if item is not None)
        return str(value)

    query_lower = cleaned_query.lower()
    exact_matches = []
    fuzzy_candidates = []

    for class_name, info in disease_guide.items():
        if not isinstance(info, dict):
            continue
        is_disease = bool(info.get("disease_name"))
        # Skip healthy classes unless specifically requested
        if not include_healthy and not is_disease:
            continue

        parts = [class_name]
        if is_disease:
            parts.extend(
                as_text(info.get(key))
                for key in ("disease_name", "description", "crop", "type", "symptoms", "common_names")
            )
        searchable_text = " ".join(parts).lower()

        if query_lower in searchable_text:
            exact_matches.append(SearchResult(
                class_name=class_name,
                class_id=create_class_id(class_name),
                disease_info=safe_create_disease_info(class_name, info if is_disease else None)
            ))
        else:
            fuzzy_candidates.append((class_name, info, searchable_text))

    if exact_matches:
        return SearchResponse(
            results=exact_matches[:limit],
            total_results=len(exact_matches),
            message=f"Found {len(exact_matches)} exact matches"
        )

    if not fuzzy_candidates:
        return SearchResponse(
            results=[],
            suggestions=[],
            total_results=0,
            message="No matches found."
        )

    # Match the cleaned, lower-cased query against the lower-cased texts
    # (rapidfuzz 3.x applies no case folding by default). partial_ratio
    # scores the best-aligned window of the long text. token_sort_ratio
    # compared the short query with the whole text and almost never cleared
    # the threshold.
    fuzzy_matches = process.extract(
        query_lower,
        [text for _, _, text in fuzzy_candidates],
        scorer=fuzz.partial_ratio,
        limit=limit
    )
    suggestions = []
    for _match_text, score, idx in fuzzy_matches:
        if score > 60:  # Minimum relevance threshold
            class_name, info, _ = fuzzy_candidates[idx]
            suggestions.append(SearchResult(
                class_name=class_name,
                class_id=create_class_id(class_name),
                disease_info=safe_create_disease_info(class_name, info if info.get("disease_name") else None),
                relevance_score=score
            ))

    return SearchResponse(
        results=[],
        suggestions=suggestions,
        total_results=len(suggestions),
        message="No exact matches found. Showing relevant suggestions." if suggestions else "No matches found."
    )
async def get_disease_by_class_id(
    class_id: str = Path(..., description="URL-safe class identifier (use class_id from other endpoints)")
):
    """
    Retrieve detailed information for a specific disease class using URL-safe class ID.

    Args:
        class_id: URL-safe identifier (as returned in ``class_id`` by the other
            endpoints); it is decoded back to a class name with ``decode_class_id``.

    Returns:
        SearchResult for the class. ``class_id`` in the response is the canonical
        id re-encoded with ``create_class_id``, matching the other endpoints.

    Raises:
        HTTPException: 400 if ``class_id`` cannot be decoded; 404 if the decoded
            class name is not in ``CLASS_NAMES``.
    """
    # Only the decode step should map UnicodeDecodeError to a client error;
    # keeping the try narrow avoids misreporting later failures as a bad id.
    try:
        class_name = decode_class_id(class_id)
    except UnicodeDecodeError as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid class ID format: '{class_id}'"
        ) from err

    # Validate that the class exists in our CLASS_NAMES
    if class_name not in CLASS_NAMES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Class with ID '{class_id}' not found in supported classes."
        )

    disease_data = disease_guide.get(class_name)
    return SearchResult(
        class_name=class_name,
        # Canonical id (same helper the search / by-name endpoints use) rather
        # than echoing whatever encoding the client happened to send.
        class_id=create_class_id(class_name),
        disease_info=safe_create_disease_info(class_name, disease_data),
    )
async def get_disease_by_class_name(
    class_name: str = Path(..., description="Exact class name (string number), e.g. '0', '1', '2'")
):
    """
    Look up a disease class directly by its exact class name.

    Alternative to the class-id endpoint: the path value is used as-is, with no
    decoding step.

    Raises:
        HTTPException: 404 when ``class_name`` is not all digits or is not listed
            in ``CLASS_NAMES``; the detail previews the first ten class names.
    """
    is_supported = class_name.isdigit() and class_name in CLASS_NAMES
    if not is_supported:
        preview = ", ".join(CLASS_NAMES[:10])
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Class '{class_name}' not found in supported classes. Supported classes: {preview}...",
        )

    canonical_id = create_class_id(class_name)
    guide_entry = disease_guide.get(class_name)
    info = safe_create_disease_info(class_name, guide_entry)
    return SearchResult(
        class_name=class_name,
        class_id=canonical_id,
        disease_info=info,
    )
async def get_api_stats():
    """
    Get API statistics and supported classes.

    Makes a single pass over ``disease_guide`` to collect:
      * the distinct ``crop``, ``type`` and ``risk_level`` values (string values
        are whitespace-stripped and blank values skipped, so " Tomato" and
        "Tomato" collapse into one entry);
      * how many guide entries have a ``disease_name`` (counted as diseases)
        versus how many do not (counted as healthy).

    Returns:
        dict with class counts, sorted category lists, whether ``model`` is
        loaded, and a map of endpoint paths.
    """
    def _normalize(value):
        # Strip surrounding whitespace from strings; other values pass through.
        return value.strip() if isinstance(value, str) else value

    crops = set()
    disease_types = set()
    risk_levels = set()
    buckets = (("crop", crops), ("type", disease_types), ("risk_level", risk_levels))
    diseases_in_guide = 0
    healthy_classes = 0

    for info in disease_guide.values():
        for key, bucket in buckets:
            value = _normalize(info.get(key))
            # Skip missing/blank values so "" never shows up as a category.
            if value:
                bucket.add(value)
        if info.get("disease_name"):
            diseases_in_guide += 1
        else:
            healthy_classes += 1

    return {
        "total_classes": len(CLASS_NAMES),
        "diseases_in_guide": diseases_in_guide,
        "healthy_classes": healthy_classes,
        "supported_crops": sorted(crops),
        "disease_types": sorted(disease_types),
        "risk_levels": sorted(risk_levels),
        "model_loaded": model is not None,
        "endpoints": {
            "prediction": "/predict",
            "all_diseases": "/diseases",
            "search": "/search",
            "disease_by_id": "/diseases/{class_id}",
            "disease_by_name": "/diseases/by-name/{class_name}",
            "health": "/health",
            "stats": "/stats"
        }
    }
async def validate_leaf_only(file: UploadFile = File(...)):
    """
    Validate if uploaded image contains a leaf without running disease prediction.

    The filename is checked with ``is_image_file``. The body is then read with a
    ``MAX_FILE_SIZE_MB`` cap, preprocessed with ``preprocess_image`` and passed
    to ``validate_leaf_image``. Whatever ``validate_leaf_image`` returns is
    returned unchanged.

    Raises:
        HTTPException: 400 for a missing or non-image filename; 413 when the
            upload exceeds ``MAX_FILE_SIZE_MB``; an ``HTTPException`` raised by
            a helper propagates with its own status; any other failure becomes
            a 500.
    """
    if not file.filename or not is_image_file(file.filename):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Please upload a valid image file"
        )

    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    image_bytes = None
    try:
        # Read at most one byte past the limit so an oversized upload is detected
        # without buffering the entire body in memory.
        image_bytes = await file.read(max_bytes + 1)
        if len(image_bytes) > max_bytes:
            raise HTTPException(
                status_code=413,  # Payload Too Large
                detail=HTTP_MESSAGES["FILE_TOO_LARGE"],
            )
        processed_image = preprocess_image(image_bytes)
        return validate_leaf_image(processed_image)
    except HTTPException:
        # Already a client-facing error; do not rewrap it as a 500.
        raise
    except Exception as e:
        logger.exception("Leaf validation failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Validation failed: {str(e)}"
        ) from e
    finally:
        # Drop the raw upload eagerly: a propagating traceback keeps this frame
        # (and its locals) alive until the exception is released.
        del image_bytes
if __name__ == "__main__":
    # Direct-run entry point. HOST and PORT can be overridden through environment
    # variables; the defaults are the previously hard-coded values.
    import uvicorn

    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        reload=False,
    )