Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import base64 | |
# Flask application object that the HTTP views in this module use.
app = Flask(__name__)
# CORS(app) with no options applies flask_cors's default policy to every route
# (NOTE(review): flask_cors's default is believed to allow any origin -- confirm
# that is intended for this endpoint). CORS is enforced by browsers, so it
# matters for browser/web clients; native mobile HTTP clients generally do not
# enforce it.
CORS(app)  # Enable CORS for mobile app
# Spatial size (height, width) every incoming image is resized to before
# inference -- the model's input resolution.
IMG_HEIGHT, IMG_WIDTH = 150, 150
# All 71 class names, positionally aligned with the model's output vector:
# the score at output index i is reported under class_names[i].
# The list is in plain ASCII sort order (uppercase sorts before lowercase,
# which is why the 'healthy (...)' entries come last).
# NOTE(review): an earlier comment here claimed 70 classes. If the model's
# output layer actually has 70 units, one label in this list is surplus and
# every label after it is shifted -- confirm against the training class order.
# Label spellings are returned verbatim to clients; do not "fix" them unless
# the model is retrained with the same names.
class_names = [
    'Algal Leaf Spot (Jackfruit)', 'Anthracnose (Mango)', 'Aphids (Cotton)',
    'Apple scab (Apple)', 'Bacterial Blight (Cotton)', 'Bacterial Canker (Mango)',
    'Bacterial Leaf Spot (Pumpkin)', 'Bacterial spot (Peach)', 'Bacterial spot (Pepper, bell)',
    'Bacterial spot (Tomato)', 'BacterialBlights (Sugarcane)', 'Black Rot (Cauliflower)',
    'Black Spot (Jackfruit)', 'Black rot (Apple)', 'Black rot (Grape)',
    'BrownSpot (Rice)', 'Cedar apple rust (Apple)', 'Cercospora leaf spot Gray leaf spot (Corn (maize))',
    'Common rust (Corn (maize))', 'Cutting Weevil (Mango)', 'Die Back (Mango)',
    'Downy Mildew (Pumpkin)', 'Early blight (Potato)', 'Early blight (Tomato)',
    'Esca (Black Measles) (Grape)', 'Gall Midge (Mango)', 'Haunglongbing (Citrus greening) (Orange)',
    'Healthy (Cauliflower)', 'Healthy (Cotton)', 'Healthy (Jackfruit)',
    'Healthy (Mango)', 'Healthy (Rice)', 'Healthy (Sugarcane)',
    'Healthy Leaf (Pumpkin)', 'Hispa (Rice)', 'Late blight (Potato)',
    'Late blight (Tomato)', 'Leaf Mold (Tomato)', 'Leaf blight (Isariopsis Leaf Spot) (Grape)',
    'Leaf scorch (Strawberry)', 'LeafBlast (Rice)', 'Mosaic (Sugarcane)',
    'Mosaic Disease (Pumpkin)', 'Northern Leaf Blight (Corn (maize))', 'Powdery Mildew (Cotton)',
    'Powdery Mildew (Mango)', 'Powdery Mildew (Pumpkin)', 'Powdery mildew (Cherry (including sour))',
    'RedRot (Sugarcane)', 'Rust (Sugarcane)', 'Septoria leaf spot (Tomato)',
    'Sooty Mould (Mango)', 'Spider mites Two-spotted spider mite (Tomato)', 'Target Spot (Tomato)',
    'Target spot (Cotton)', 'Tomato Yellow Leaf Curl Virus (Tomato)', 'Tomato mosaic virus (Tomato)',
    'Unknown Disease', 'Yellow (Sugarcane)', 'healthy (Apple)',
    'healthy (Blueberry)', 'healthy (Cherry (including sour))', 'healthy (Corn (maize))',
    'healthy (Grape)', 'healthy (Peach)', 'healthy (Pepper, bell)',
    'healthy (Potato)', 'healthy (Raspberry)', 'healthy (Soybean)',
    'healthy (Strawberry)', 'healthy (Tomato)',
]
def _load_serving_signature(export_dir):
    """Load the SavedModel at ``export_dir``.

    Returns:
        A ``(loaded_object, signature_function)`` pair, where the signature
        function is the loaded object's ``"serving_default"`` signature.
    """
    loaded = tf.saved_model.load(export_dir)
    return loaded, loaded.signatures["serving_default"]


# The model is loaded once, at import time. Any failure here (missing
# directory, no "serving_default" signature) raises during import, so the
# server never starts with a half-loaded model.
# `model` stays bound alongside `infer` (presumably the signature function
# relies on the loaded object staying alive -- keep it referenced).
print("Loading model...")
model, infer = _load_serving_signature('./plant_disease_savemodel')
print("✅ Model loaded successfully")
# NOTE(review): the visible source never registered this view with a route,
# so Flask would not serve it at all. '/health' is an assumed path -- confirm
# it matches what the client or host pings.
@app.route('/health', methods=['GET'])
def health():
    """Health-check endpoint.

    Returns:
        JSON ``{"status": "healthy", "model_loaded": True}`` with HTTP 200.
        ``model_loaded`` is constant because the model is loaded at import time
        with no error handling. If that load had failed, this process would
        not be serving requests.
    """
    return jsonify({"status": "healthy", "model_loaded": True})
def _decode_image(image_data):
    """Turn a base64 image string into a batch of one normalized RGB image.

    Args:
        image_data: Base64-encoded image bytes, optionally prefixed with a
            data-URL header such as ``data:image/jpeg;base64,``.

    Returns:
        A float32 NumPy array of shape ``(1, IMG_HEIGHT, IMG_WIDTH, 3)``, with
        pixel values scaled from 0-255 to 0.0-1.0.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error``) or is
            not ASCII.
        OSError: Pillow cannot identify or fully read the image.
        Image.DecompressionBombError: the image is implausibly large.
    """
    # A data URL is "<header>,<payload>". Base64 never contains commas, so
    # everything after the first comma is the payload.
    if ',' in image_data:
        image_data = image_data.split(',', 1)[1]
    # The default (non-validating) decode discards characters outside the
    # base64 alphabet, e.g. line breaks that some encoders insert.
    image_bytes = base64.b64decode(image_data)
    # The context manager closes the source image. convert() and resize()
    # return new images that remain usable after it is closed.
    with Image.open(io.BytesIO(image_bytes)) as img:
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        # PIL sizes are (width, height).
        resized = rgb.resize((IMG_WIDTH, IMG_HEIGHT))
    img_array = np.array(resized, dtype=np.float32) / 255.0
    # Add the leading batch dimension.
    return np.expand_dims(img_array, axis=0)


def _run_inference(batch):
    """Run the serving signature on ``batch`` and return its output dict."""
    tensor = tf.constant(batch)
    # Signature functions describe their inputs as keyword arguments. When
    # there is exactly one input, pass it by name so the call doesn't depend on
    # positional-argument support. Otherwise keep the original positional call.
    _, input_specs = infer.structured_input_signature
    if len(input_specs) == 1:
        (input_name,) = input_specs
        return infer(**{input_name: tensor})
    return infer(tensor)


def _select_output(predictions):
    """Return the score array from the signature's output dict.

    Known output names are tried in order before falling back to the first
    output. NOTE(review): which name this SavedModel actually uses is not
    visible here.
    """
    for key in ('output_0', 'dense_1', 'dense'):
        if key in predictions:
            return predictions[key].numpy()
    return list(predictions.values())[0].numpy()


# NOTE(review): the visible source never registered this view with a route,
# so Flask would not serve it at all. '/predict' (POST) is an assumed path --
# confirm it matches the mobile client.
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image sent as base64 in a JSON body.

    Request body: ``{"image": "<base64 string or data URL>"}``.

    Responses:
        200: ``{"success": True, "predictions": {class_name: score, ...},
             "top_prediction": {"class": name, "confidence": score}}``.
        400: ``{"error": ...}`` when the body is not a JSON object, ``image``
             is missing or not a string, or the payload cannot be decoded as
             an image.
        500: ``{"error": ...}`` for unexpected failures during inference.
    """
    # silent=True returns None for a missing, non-JSON or malformed body
    # instead of raising, so it can be answered with a 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    # Get base64 image from request
    image_data = data.get('image')
    if not image_data:
        return jsonify({"error": "No image provided"}), 400
    if not isinstance(image_data, str):
        return jsonify({"error": "'image' must be a base64-encoded string"}), 400

    # Undecodable input is a client error: report it as 400, not 500.
    try:
        img_batch = _decode_image(image_data)
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        print(f"Error: {str(e)}")
        return jsonify({"error": f"Invalid image: {e}"}), 400

    try:
        scores = _select_output(_run_inference(img_batch))[0]
        if len(scores) != len(class_names):
            # Labels are matched to scores by position. With a count
            # mismatch, some labels may be missing or misaligned.
            print(f"Warning: model returned {len(scores)} scores "
                  f"but {len(class_names)} class names are defined")
        # zip() stops at the shorter sequence, like the original index guard.
        predictions_dict = {
            name: float(score) for name, score in zip(class_names, scores)
        }
        if not predictions_dict:
            raise ValueError("Model returned no class scores")
        # max() returns the first maximal item, so ties resolve to list order.
        top_name, top_score = max(predictions_dict.items(), key=lambda item: item[1])
        print(f"Prediction: {top_name} ({top_score*100:.2f}%)")
        return jsonify({
            "success": True,
            "predictions": predictions_dict,
            "top_prediction": {
                "class": top_name,
                "confidence": float(top_score)
            }
        })
    except Exception as e:
        # Top-level boundary: log and report any inference failure as a 500.
        print(f"Error: {str(e)}")
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    import os

    # Flask's built-in development server, listening on all interfaces.
    # 7860 remains the default port. A PORT environment variable, when set,
    # overrides it for hosts that assign the port dynamically.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))