sidd-harth011 commited on
Commit
5636459
·
1 Parent(s): 8e40aa8

intial docker deployment

Browse files
Files changed (4) hide show
  1. Dockerfile +18 -0
  2. app.py +114 -0
  3. class_indices.json +1 -0
  4. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use official Python image
2
+ FROM python:3.11-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Copy requirements and install
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy the rest of the app
12
+ COPY . .
13
+
14
+ # Expose the port HF Spaces uses
15
+ EXPOSE 7860
16
+
17
+ # Run the Flask app
18
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from PIL import Image
6
+ from flask import Flask, request, jsonify
7
+ from flask_cors import CORS
8
+ import io
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # Initialize Flask app
12
+ app = Flask(__name__)
13
+ CORS(app) # Enable CORS for all routes
14
+
15
+ # Load model and class indices
16
+ working_dir = os.path.dirname(os.path.abspath(__file__))
17
+ #model_path = os.path.join(working_dir, "trained_model", "plant_disease_model.tflite")
18
+ model_path = hf_hub_download(
19
+ repo_id="sidd-harth011/checkingPDRMod", # ✅ your repo
20
+ filename="plant_disease_model.tflite"
21
+ )
22
+
23
+ # Load the TFLite model
24
+ interpreter = tf.lite.Interpreter(model_path=model_path)
25
+ interpreter.allocate_tensors()
26
+ input_details = interpreter.get_input_details()
27
+ output_details = interpreter.get_output_details()
28
+
29
+ # Load class indices
30
+ class_indices_path = os.path.join(working_dir, "class_indices.json")
31
+ with open(class_indices_path, 'r') as f:
32
+ class_indices = json.load(f)
33
+
34
+ # -----------------------------
35
+ # Preprocessing function
36
+ # -----------------------------
37
+ def load_and_preprocess_image(image, target_size=(224, 224)):
38
+ img = image.resize(target_size)
39
+ img_array = np.array(img, dtype=np.float32)
40
+ img_array = np.expand_dims(img_array, axis=0)
41
+ img_array = img_array / 255.0
42
+ return img_array
43
+
44
+ # -----------------------------
45
+ # Function to clean label
46
+ # -----------------------------
47
+ def clean_label(label: str) -> str:
48
+ if "___" in label:
49
+ label = label.split("___")[-1]
50
+ return label.replace("_", " ").title()
51
+
52
+ # -----------------------------
53
+ # Prediction function
54
+ # -----------------------------
55
+ def predict_image_class(image):
56
+ preprocessed_img = load_and_preprocess_image(image)
57
+ interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
58
+ interpreter.invoke()
59
+ predictions = interpreter.get_tensor(output_details[0]['index'])
60
+ predicted_class_index = np.argmax(predictions, axis=1)[0]
61
+ predicted_class_name = class_indices[str(predicted_class_index)]
62
+ predicted_class_name = clean_label(predicted_class_name)
63
+
64
+ # Get confidence score
65
+ confidence = float(predictions[0][predicted_class_index])
66
+
67
+ return predicted_class_name, confidence
68
+
69
+ # -----------------------------
70
+ # API endpoint for image classification
71
+ # -----------------------------
72
+ @app.route('/predict', methods=['POST'])
73
+ def predict():
74
+ try:
75
+ # Check if image is in the request
76
+ if 'image' not in request.files:
77
+ return jsonify({'error': 'No image provided'}), 400
78
+
79
+ # Get the image file
80
+ image_file = request.files['image']
81
+
82
+ # Check if filename is empty
83
+ if image_file.filename == '':
84
+ return jsonify({'error': 'No image selected'}), 400
85
+
86
+ # Read and process the image
87
+ image = Image.open(io.BytesIO(image_file.read()))
88
+
89
+ # Make prediction
90
+ predicted_class, confidence = predict_image_class(image)
91
+
92
+ # Return prediction as JSON
93
+ return jsonify({
94
+ 'prediction': predicted_class,
95
+ 'confidence': confidence,
96
+ 'status': 'success'
97
+ })
98
+
99
+ except Exception as e:
100
+ return jsonify({'error': str(e), 'status': 'error'}), 500
101
+
102
+ # -----------------------------
103
+ # Health check endpoint
104
+ # -----------------------------
105
+ @app.route('/health', methods=['GET'])
106
+ def health_check():
107
+ return jsonify({'status': 'healthy', 'message': 'Plant Disease Classification API is running'})
108
+
109
+ # -----------------------------
110
+ # Run the Flask app
111
+ # -----------------------------
112
+ if __name__ == '__main__':
113
+ # You can change the host and port as needed
114
+ app.run(host='0.0.0.0', port=7860, debug=False)
class_indices.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "Apple___Apple_scab", "1": "Apple___Black_rot", "2": "Apple___Cedar_apple_rust", "3": "Apple___healthy", "4": "Blueberry___healthy", "5": "Cherry_(including_sour)___Powdery_mildew", "6": "Cherry_(including_sour)___healthy", "7": "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot", "8": "Corn_(maize)___Common_rust_", "9": "Corn_(maize)___Northern_Leaf_Blight", "10": "Corn_(maize)___healthy", "11": "Grape___Black_rot", "12": "Grape___Esca_(Black_Measles)", "13": "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "14": "Grape___healthy", "15": "Orange___Haunglongbing_(Citrus_greening)", "16": "Peach___Bacterial_spot", "17": "Peach___healthy", "18": "Pepper,_bell___Bacterial_spot", "19": "Pepper,_bell___healthy", "20": "Potato___Early_blight", "21": "Potato___Late_blight", "22": "Potato___healthy", "23": "Raspberry___healthy", "24": "Soybean___healthy", "25": "Squash___Powdery_mildew", "26": "Strawberry___Leaf_scorch", "27": "Strawberry___healthy", "28": "Tomato___Bacterial_spot", "29": "Tomato___Early_blight", "30": "Tomato___Late_blight", "31": "Tomato___Leaf_Mold", "32": "Tomato___Septoria_leaf_spot", "33": "Tomato___Spider_mites Two-spotted_spider_mite", "34": "Tomato___Target_Spot", "35": "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "36": "Tomato___Tomato_mosaic_virus", "37": "Tomato___healthy"}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tensorflow
2
+ numpy
3
+ flask
4
+ flask-cors
5
+ pillow
6
+ huggingface-hub