Spaces:
Sleeping
Sleeping
import os
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Target size in pixels (height, width) for images handed to the model
IMG_HEIGHT, IMG_WIDTH = 150, 150
# Class labels for the classifier output, in ASCII sort order.
# Position matters: label i is paired with score i of the model output,
# so entries must not be reordered, renamed, or removed.
# NOTE(review): this list holds 71 entries (including 'Unknown Disease'),
# while the UI description advertises 70 classes -- confirm the count
# against the width of the model's output layer.
class_names = [
    'Algal Leaf Spot (Jackfruit)',
    'Anthracnose (Mango)',
    'Aphids (Cotton)',
    'Apple scab (Apple)',
    'Bacterial Blight (Cotton)',
    'Bacterial Canker (Mango)',
    'Bacterial Leaf Spot (Pumpkin)',
    'Bacterial spot (Peach)',
    'Bacterial spot (Pepper, bell)',
    'Bacterial spot (Tomato)',
    'BacterialBlights (Sugarcane)',
    'Black Rot (Cauliflower)',
    'Black Spot (Jackfruit)',
    'Black rot (Apple)',
    'Black rot (Grape)',
    'BrownSpot (Rice)',
    'Cedar apple rust (Apple)',
    'Cercospora leaf spot Gray leaf spot (Corn (maize))',
    'Common rust (Corn (maize))',
    'Cutting Weevil (Mango)',
    'Die Back (Mango)',
    'Downy Mildew (Pumpkin)',
    'Early blight (Potato)',
    'Early blight (Tomato)',
    'Esca (Black Measles) (Grape)',
    'Gall Midge (Mango)',
    'Haunglongbing (Citrus greening) (Orange)',
    'Healthy (Cauliflower)',
    'Healthy (Cotton)',
    'Healthy (Jackfruit)',
    'Healthy (Mango)',
    'Healthy (Rice)',
    'Healthy (Sugarcane)',
    'Healthy Leaf (Pumpkin)',
    'Hispa (Rice)',
    'Late blight (Potato)',
    'Late blight (Tomato)',
    'Leaf Mold (Tomato)',
    'Leaf blight (Isariopsis Leaf Spot) (Grape)',
    'Leaf scorch (Strawberry)',
    'LeafBlast (Rice)',
    'Mosaic (Sugarcane)',
    'Mosaic Disease (Pumpkin)',
    'Northern Leaf Blight (Corn (maize))',
    'Powdery Mildew (Cotton)',
    'Powdery Mildew (Mango)',
    'Powdery Mildew (Pumpkin)',
    'Powdery mildew (Cherry (including sour))',
    'RedRot (Sugarcane)',
    'Rust (Sugarcane)',
    'Septoria leaf spot (Tomato)',
    'Sooty Mould (Mango)',
    'Spider mites Two-spotted spider mite (Tomato)',
    'Target Spot (Tomato)',
    'Target spot (Cotton)',
    'Tomato Yellow Leaf Curl Virus (Tomato)',
    'Tomato mosaic virus (Tomato)',
    'Unknown Disease',
    'Yellow (Sugarcane)',
    'healthy (Apple)',
    'healthy (Blueberry)',
    'healthy (Cherry (including sour))',
    'healthy (Corn (maize))',
    'healthy (Grape)',
    'healthy (Peach)',
    'healthy (Pepper, bell)',
    'healthy (Potato)',
    'healthy (Raspberry)',
    'healthy (Soybean)',
    'healthy (Strawberry)',
    'healthy (Tomato)',
]
# Load the TensorFlow SavedModel once, at import time.
# On success `model` holds the loaded object and `infer` its
# "serving_default" signature; on any failure both stay None.
print("Loading model...")
print(f"Current directory: {os.getcwd()}")
print(f"Files in current directory: {os.listdir('.')}")

model = None
infer = None

try:
    # Candidate locations, probed in order; the first existing one is used
    possible_paths = [
        './plant_disease_savemodel',
        './plant_disease_savedmodel',
        'plant_disease_savemodel',
        'plant_disease_savedmodel',
    ]
    model_path = next(
        (candidate for candidate in possible_paths if os.path.exists(candidate)),
        None,
    )
    if model_path is None:
        raise FileNotFoundError("Model directory not found. Please ensure 'plant_disease_savemodel' folder is uploaded.")
    print(f"Found model at: {model_path}")

    # List the directory contents to make deployment problems easier to diagnose
    model_files = os.listdir(model_path)
    print(f"Files in model directory: {model_files}")

    model = tf.saved_model.load(model_path)
    infer = model.signatures["serving_default"]
    print(f"✅ Model loaded successfully from {model_path}")
except Exception as load_error:
    # Best effort: report the failure and leave the app running without a model
    print(f"❌ Error loading model: {load_error}")
    import traceback
    traceback.print_exc()
    model = infer = None
def _error_result(message):
    """Return an error payload shaped like a prediction: one label, confidence 1.0."""
    # Every value must be a float because the result is rendered by gr.Label,
    # which expects a label -> confidence mapping; the message therefore
    # travels in the label text rather than as a separate string value.
    return {f"Error: {message}": 1.0}


def predict_disease(image):
    """
    Predict plant disease from an image.

    Args:
        image: PIL Image or numpy array (None when nothing was uploaded).

    Returns:
        dict: Maps each class name to its score as a float. On failure the
        dict holds a single "Error: <reason>" key with value 1.0, so the
        result is always a str -> float mapping.
        NOTE(review): earlier notes say the CropGuard mobile app consumes
        this mapping -- confirm the error format with that client.
    """
    if model is None or infer is None:
        return _error_result("Model not loaded. Please check the model files.")
    if image is None:
        return _error_result("No image provided. Please upload a plant leaf image.")

    try:
        # Let Pillow infer the mode from the array so that grayscale or RGBA
        # input reaches the RGB conversion below instead of being forced.
        if isinstance(image, np.ndarray):
            img = Image.fromarray(image.astype('uint8'))
        else:
            img = image

        if img.mode != 'RGB':
            img = img.convert('RGB')

        # PIL's resize takes (width, height)
        img = img.resize((IMG_WIDTH, IMG_HEIGHT))

        # Scale pixels to [0, 1] and add the batch dimension
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_array = np.expand_dims(img_array, axis=0)

        predictions = infer(tf.constant(img_array))

        # The output key depends on how the model was exported: try the
        # known names in order, then fall back to the first output.
        for key in ('output_0', 'dense_1', 'dense'):
            if key in predictions:
                output = predictions[key].numpy()
                break
        else:
            output = next(iter(predictions.values())).numpy()

        # assumes output is (batch, num_classes) with a single-image batch
        scores = output[0]
        if len(scores) != len(class_names):
            # zip() below truncates to the shorter side; make the mismatch visible
            print(f"Warning: model returned {len(scores)} scores for {len(class_names)} class names")

        predictions_dict = {
            class_name: float(score)
            for class_name, score in zip(class_names, scores)
        }

        # Log top prediction for debugging
        top_class = max(predictions_dict.items(), key=lambda item: item[1])
        print(f"Top prediction: {top_class[0]} ({top_class[1]*100:.2f}%)")

        return predictions_dict
    except Exception as e:
        # UI boundary: log the full traceback, return a displayable error
        print(f"Prediction error: {str(e)}")
        traceback.print_exc()
        return _error_result(f"Prediction failed: {str(e)}")
# ---- Gradio UI ----
# Page text (Markdown) passed to the interface below
title = "🌱 CropGuard Tech - Plant Disease Detection"

description = """
Upload an image of a plant leaf to detect diseases using AI.
**Supported Crops:** Apple, Blueberry, Cauliflower, Cherry, Corn, Cotton, Grape, Jackfruit, Mango, Orange, Peach, Pepper, Potato, Pumpkin, Raspberry, Rice, Soybean, Strawberry, Sugarcane, Tomato
**Model Specs:**
- 70 disease classes
- 95%+ accuracy
- CNN architecture
- Trained on 10,000+ images
"""

article = """
### About CropGuard Tech
This AI model was trained on Google Colab using a comprehensive plant disease dataset from Kaggle.
It can identify 70 different plant diseases across 19+ crop varieties.
**Model Repository:** [View on Hugging Face](https://huggingface.co/4lph4v3rs3/plant-disease-classification-model)
"""

# No example images are bundled; add entries here if some become available
examples = []

# One image upload in, the five highest-scoring labels out
leaf_input = gr.Image(label="Upload Plant Leaf Image")
top_predictions = gr.Label(num_top_classes=5, label="Disease Predictions")

iface = gr.Interface(
    fn=predict_disease,
    inputs=leaf_input,
    outputs=top_predictions,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# Start the web server only when this file is run as a script
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,
        "ssr_mode": False,  # Disable SSR to avoid hot reload errors
    }
    iface.launch(**launch_options)