File size: 3,845 Bytes
5636459
 
 
 
 
 
 
 
 
 
2f54c6b
67e4602
2f54c6b
5636459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import io
import json
import os
import threading

import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from PIL import Image

# Hugging Face cache location inside the container. huggingface_hub resolves
# HF_HUB_CACHE when the package is imported, and the import above has already
# happened, so setting the variable here does not move the default cache.
# The directory is therefore also passed explicitly to hf_hub_download below.
HF_CACHE_DIR = '/tmp/hf_cache'
os.environ['HF_HUB_CACHE'] = HF_CACHE_DIR

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Directory containing this file; bundled resources are resolved against it.
working_dir = os.path.dirname(os.path.abspath(__file__))

# Fetch the TFLite model from the Hugging Face Hub (reusing the cached copy
# in HF_CACHE_DIR when it is already present).
model_path = hf_hub_download(
    repo_id="sidd-harth011/checkingPDRMod",
    filename="plant_disease_model.tflite",
    cache_dir=HF_CACHE_DIR,
)

# Load the TFLite model once at import time and cache its input/output tensor
# metadata; this single interpreter is shared by every request.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Load the class-index -> label mapping stored next to this file. Its keys are
# looked up with the string form of the predicted index. JSON text is UTF-8,
# so decode it explicitly instead of relying on the platform's locale encoding.
class_indices_path = os.path.join(working_dir, "class_indices.json")
with open(class_indices_path, 'r', encoding='utf-8') as f:
    class_indices = json.load(f)

# -----------------------------
# Preprocessing function
# -----------------------------
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a single-item float32 input batch.

    Args:
        image: PIL image in any mode (RGB, RGBA, palette, grayscale, ...).
        target_size: (width, height) passed to ``Image.resize``.

    Returns:
        float32 array of shape (1, height, width, 3) with values in [0, 1].
    """
    # Normalize to 3-channel RGB. PNG uploads are often RGBA or palette images
    # and some photos are grayscale; without this the array would have 4
    # channels or none, and set_tensor would reject it.
    # NOTE(review): assumes the model was trained on RGB input — confirm
    # against input_details[0]['shape'].
    if image.mode != 'RGB':
        image = image.convert('RGB')

    resized = image.resize(target_size)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)

# -----------------------------
# Function to clean label
# -----------------------------
def clean_label(label: str) -> str:
    """Return a human-readable label from a raw class name.

    Everything up to and including the last ``___`` separator is dropped
    (the whole string is kept when there is no separator). Remaining
    underscores become spaces and the result is title-cased.
    """
    # str.split returns [label] when the separator is absent, so taking the
    # last piece covers both the separated and the plain case.
    tail = label.split("___")[-1]
    spaced = tail.replace("_", " ")
    return spaced.title()

# -----------------------------
# Prediction function
# -----------------------------
# The TFLite interpreter is shared module state and is not thread-safe, while
# Flask may handle requests on several threads at once. Without this lock, two
# requests can interleave set_tensor/invoke/get_tensor and read each other's
# results.
_interpreter_lock = threading.Lock()


def predict_image_class(image):
    """Classify a PIL image with the shared TFLite model.

    Args:
        image: PIL image to classify.

    Returns:
        (label, confidence): the cleaned class label and the model's output
        value for that class, as a Python float.

    Raises:
        KeyError: if the predicted index has no entry in ``class_indices``.
    """
    preprocessed_img = load_and_preprocess_image(image)

    # Hold the lock for the whole write -> run -> read sequence so the output
    # read belongs to this request's input.
    with _interpreter_lock:
        interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
        interpreter.invoke()
        predictions = interpreter.get_tensor(output_details[0]['index'])

    scores = predictions[0]  # single-item batch
    predicted_class_index = int(np.argmax(scores))
    predicted_class_name = clean_label(class_indices[str(predicted_class_index)])

    # NOTE(review): this is a probability only if the model ends in softmax —
    # confirm; otherwise it is a raw score.
    confidence = float(scores[predicted_class_index])

    return predicted_class_name, confidence

# -----------------------------
# API endpoint for image classification
# -----------------------------
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded in the ``image`` multipart form field.

    Returns:
        200 with ``prediction``, ``confidence`` and ``status: 'success'``.
        400 with ``error`` and ``status: 'error'`` when no file is sent or the
        file cannot be decoded as an image.
        500 with ``error`` and ``status: 'error'`` for any other failure.
    """
    try:
        # Check if image is in the request
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided', 'status': 'error'}), 400

        # Get the image file
        image_file = request.files['image']

        # An empty filename means the form was submitted without choosing a file
        if image_file.filename == '':
            return jsonify({'error': 'No image selected', 'status': 'error'}), 400

        # Decode eagerly. Image.open is lazy, so without load() a corrupt upload
        # would only fail later inside the model pipeline and be reported as a
        # server error. PIL raises UnidentifiedImageError (an OSError subclass)
        # for unrecognized formats and OSError for truncated/corrupt data;
        # both are client errors.
        try:
            image = Image.open(io.BytesIO(image_file.read()))
            image.load()
        except OSError as e:
            return jsonify({'error': f'Invalid image file: {e}', 'status': 'error'}), 400

        # Make prediction
        predicted_class, confidence = predict_image_class(image)

        # Return prediction as JSON
        return jsonify({
            'prediction': predicted_class,
            'confidence': confidence,
            'status': 'success'
        })

    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # instead of discarding it, then report the failure to the client.
        app.logger.exception('Prediction request failed')
        return jsonify({'error': str(e), 'status': 'error'}), 500

# -----------------------------
# Health check endpoint
# -----------------------------
@app.route('/health', methods=['GET'])
def health_check():
    """Return a fixed JSON payload showing the API process is up and serving."""
    payload = {
        'status': 'healthy',
        'message': 'Plant Disease Classification API is running',
    }
    return jsonify(payload)

# -----------------------------
# Run the Flask app
# -----------------------------
if __name__ == '__main__':
    # Listen on all interfaces (so the server is reachable from outside the
    # container) on port 7860, with the debugger disabled.
    server_options = {'host': '0.0.0.0', 'port': 7860, 'debug': False}
    app.run(**server_options)