Spaces:
Sleeping
Sleeping
File size: 3,845 Bytes
5636459 2f54c6b 67e4602 2f54c6b 5636459 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import io
import json
import os
import threading

import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from PIL import Image
# Direct the Hugging Face Hub cache to a writable folder inside the container.
# NOTE(review): huggingface_hub is imported above this point and may already
# have resolved its cache directory; pass cache_dir explicitly where it matters.
os.environ.update(HF_HUB_CACHE='/tmp/hf_cache')

# Flask application object, with CORS enabled on every route.
app = Flask(__name__)
CORS(app)
# Load model and class indices
working_dir = os.path.dirname(os.path.abspath(__file__))

# Fetch the TFLite model from the Hugging Face Hub, reusing a cached copy
# when one is present. huggingface_hub resolves its default cache directory
# when it is imported, so an HF_HUB_CACHE value set after that import is
# ignored unless it is passed explicitly as cache_dir. If the variable is
# unset, cache_dir=None falls back to the library default.
# A locally bundled copy could instead be loaded from
# os.path.join(working_dir, "trained_model", "plant_disease_model.tflite").
model_path = hf_hub_download(
    repo_id="sidd-harth011/checkingPDRMod",
    filename="plant_disease_model.tflite",
    cache_dir=os.environ.get('HF_HUB_CACHE'),
)

# Build the TFLite interpreter once at import time and allocate its tensors.
# After that, each request only needs set_tensor / invoke / get_tensor.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# class_indices.json lives next to this file. It is indexed by the
# stringified argmax class index to obtain the raw class label.
class_indices_path = os.path.join(working_dir, "class_indices.json")
with open(class_indices_path, 'r', encoding='utf-8') as f:
    class_indices = json.load(f)
# -----------------------------
# Preprocessing function
# -----------------------------
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Resize a PIL image and convert it to a normalised single-image batch.

    Args:
        image: A PIL ``Image`` in any mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)`` with pixel
        values scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    # Uploaded PNGs are often RGBA or palette images, and some photos are
    # grayscale. Force three channels so the array matches the RGB input the
    # model presumably expects; otherwise set_tensor rejects the shape.
    if image.mode != "RGB":
        image = image.convert("RGB")
    resized = image.resize(target_size)
    img_array = np.asarray(resized, dtype=np.float32)
    # Add the leading batch axis, then normalise. Dividing by a Python
    # float keeps the float32 dtype that the TFLite input tensor requires.
    img_array = img_array[np.newaxis, ...] / 255.0
    return img_array
# -----------------------------
# Function to clean label
# -----------------------------
def clean_label(label: str) -> str:
    """Convert a raw dataset label into a human-readable disease name.

    The text after the first-found ``"___"`` separator chain's final piece is
    kept (the whole label if there is no separator). Underscores become
    spaces and the result is title-cased, e.g.
    ``"Tomato___Early_blight"`` -> ``"Early Blight"``.
    """
    # str.split returns [label] when the separator is absent, so no
    # membership check is needed before taking the last piece.
    disease_part = label.split("___")[-1]
    readable = disease_part.replace("_", " ")
    return readable.title()
# -----------------------------
# Prediction function
# -----------------------------
# The single module-level TFLite interpreter is shared by all requests.
# Flask's built-in server handles requests on multiple threads by default,
# and the interpreter must not receive other calls while invoke() runs.
# This lock keeps each set_tensor -> invoke -> get_tensor sequence atomic.
_interpreter_lock = threading.Lock()


def predict_image_class(image):
    """Classify a PIL image with the shared TFLite interpreter.

    Args:
        image: A PIL ``Image``; it is preprocessed by
            ``load_and_preprocess_image``.

    Returns:
        ``(label, confidence)``:
        - ``label`` is the cleaned class name looked up in ``class_indices``
          for the highest-scoring output.
        - ``confidence`` is that output's score as a Python ``float``. It is
          a probability only if the model ends in a softmax, which cannot be
          verified here.
    """
    preprocessed_img = load_and_preprocess_image(image)
    with _interpreter_lock:
        interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
        interpreter.invoke()
        # get_tensor returns a copy, so the scores stay valid after the lock
        # is released and another request reuses the interpreter.
        predictions = interpreter.get_tensor(output_details[0]['index'])

    # The input batch holds a single image, so only row 0 is relevant.
    scores = predictions[0]
    predicted_class_index = int(np.argmax(scores))
    predicted_class_name = clean_label(class_indices[str(predicted_class_index)])
    confidence = float(scores[predicted_class_index])
    return predicted_class_name, confidence
# -----------------------------
# API endpoint for image classification
# -----------------------------
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the multipart form field ``image``.

    Returns:
        A JSON response with one of these statuses:
        - 200 ``{'prediction', 'confidence', 'status': 'success'}`` on success.
        - 400 ``{'error'}`` if the ``image`` field is missing or its filename
          is empty.
        - 400 ``{'error', 'status': 'error'}`` if the upload cannot be decoded
          as an image.
        - 500 ``{'error', 'status': 'error'}`` for any other failure.
    """
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        image_file = request.files['image']
        # A form submitted without choosing a file sends an empty filename.
        if image_file.filename == '':
            return jsonify({'error': 'No image selected'}), 400

        # Image.open() is lazy and only identifies the file. load() forces a
        # full decode, so non-image or truncated uploads are reported here as
        # a client error instead of failing later as a 500.
        # Pillow's UnidentifiedImageError is a subclass of OSError.
        try:
            image = Image.open(io.BytesIO(image_file.read()))
            image.load()
        except OSError:
            return jsonify({'error': 'Invalid image file', 'status': 'error'}), 400

        predicted_class, confidence = predict_image_class(image)
        return jsonify({
            'prediction': predicted_class,
            'confidence': confidence,
            'status': 'success',
        })
    except Exception as e:
        # The exception is handled here, so Flask will not log it itself.
        # Keep the traceback in the server log for debugging.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e), 'status': 'error'}), 500
# -----------------------------
# Health check endpoint
# -----------------------------
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up, using a fixed JSON payload."""
    payload = {
        'status': 'healthy',
        'message': 'Plant Disease Classification API is running',
    }
    return jsonify(payload)
# -----------------------------
# Run the Flask app
# -----------------------------
if __name__ == '__main__':
    # Start Flask's built-in server on all interfaces, port 7860, with debug
    # mode off. Change host/port here if deploying elsewhere.
    app.run(debug=False, host='0.0.0.0', port=7860)