Update app.py
Browse files
app.py
CHANGED
|
@@ -1,60 +1,92 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
| 3 |
import tensorflow as tf
|
| 4 |
-
from tensorflow.keras.models import load_model
|
| 5 |
from tensorflow.keras.preprocessing import image
|
| 6 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
|
| 31 |
-
img_array /= 255.0 # Normalize the image if required by the model
|
| 32 |
-
|
| 33 |
-
# Predict the class of the tomato leaf disease
|
| 34 |
-
predictions = model.predict(img_array)
|
| 35 |
-
|
| 36 |
-
# Assuming the model returns probabilities, get the class with the highest probability
|
| 37 |
-
predicted_class = np.argmax(predictions, axis=1)[0]
|
| 38 |
-
|
| 39 |
-
# You can map the class index to the disease name if required
|
| 40 |
-
disease_classes = ['Bacterial_spot', 'Early_blight', 'Late_blight', 'Tomato_mosaic_virus', 'Tomato_Yellow_Leaf_Curl_Virus']
|
| 41 |
-
predicted_disease = disease_classes[predicted_class]
|
| 42 |
-
|
| 43 |
-
return predicted_disease
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import numpy as np
|
| 3 |
import tensorflow as tf
|
|
|
|
| 4 |
from tensorflow.keras.preprocessing import image
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
# Suppress TensorFlow warnings
|
| 8 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 9 |
+
device = "cuda" if tf.test.is_gpu_available() else "cpu"
|
| 10 |
+
print(f"Running on: {device.upper()}")
|
| 11 |
+
|
| 12 |
+
# Load the trained tomato disease detection model
|
| 13 |
+
model = tf.keras.models.load_model("Tomato_Leaf_Disease_Model.h5")
|
| 14 |
+
|
| 15 |
+
# Disease categories
|
| 16 |
+
class_labels = [
|
| 17 |
+
"Tomato Bacterial Spot",
|
| 18 |
+
"Tomato Early Blight",
|
| 19 |
+
"Tomato Late Blight",
|
| 20 |
+
"Tomato Mosaic Virus",
|
| 21 |
+
"Tomato Yellow Leaf Curl Virus"
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
# Image preprocessing function
|
| 25 |
+
def preprocess_image(img):
|
| 26 |
+
img = img.resize((224, 224)) # Resize for model input
|
| 27 |
+
img = image.img_to_array(img) / 255.0 # Normalize
|
| 28 |
+
return np.expand_dims(img, axis=0) # Add batch dimension
|
| 29 |
|
| 30 |
+
# Temperature Scaling: Adjusts predictions using a temperature parameter.
|
| 31 |
+
def apply_temperature_scaling(prediction, temperature):
|
| 32 |
+
# Avoid log(0) by adding a small epsilon
|
| 33 |
+
eps = 1e-8
|
| 34 |
+
scaled_logits = np.log(np.maximum(prediction, eps)) / temperature
|
| 35 |
+
exp_logits = np.exp(scaled_logits)
|
| 36 |
+
scaled_probs = exp_logits / np.sum(exp_logits)
|
| 37 |
+
return scaled_probs
|
| 38 |
+
|
| 39 |
+
# Min-Max Normalization: Scales the raw confidence based on provided min and max values.
|
| 40 |
+
def apply_min_max_scaling(confidence, min_conf, max_conf):
|
| 41 |
+
norm = (confidence - min_conf) / (max_conf - min_conf) * 100
|
| 42 |
+
norm = np.clip(norm, 0, 100)
|
| 43 |
+
return norm
|
| 44 |
+
|
| 45 |
+
# Main detection function with adjustable confidence scaling
|
| 46 |
+
def detect_disease_scaled(img, scaling_method, temperature, min_conf, max_conf):
|
| 47 |
+
processed_img = preprocess_image(img)
|
| 48 |
+
prediction = model.predict(processed_img)[0] # Get prediction for single image
|
| 49 |
+
raw_confidence = np.max(prediction) * 100
|
| 50 |
+
class_idx = np.argmax(prediction)
|
| 51 |
+
disease_name = class_labels[class_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
if scaling_method == "Temperature Scaling":
|
| 54 |
+
scaled_probs = apply_temperature_scaling(prediction, temperature)
|
| 55 |
+
adjusted_confidence = np.max(scaled_probs) * 100
|
| 56 |
+
elif scaling_method == "Min-Max Normalization":
|
| 57 |
+
adjusted_confidence = apply_min_max_scaling(raw_confidence, min_conf, max_conf)
|
| 58 |
+
else:
|
| 59 |
+
adjusted_confidence = raw_confidence
|
| 60 |
+
|
| 61 |
+
# Return both the adjusted result and the raw confidence value
|
| 62 |
+
result = f"{disease_name} (Confidence: {adjusted_confidence:.2f}%)"
|
| 63 |
+
raw_text = f"Raw Confidence: {raw_confidence:.2f}%"
|
| 64 |
+
return result, raw_text
|
| 65 |
+
|
| 66 |
+
# Gradio UI
|
| 67 |
+
with gr.Blocks() as demo:
|
| 68 |
+
gr.Markdown("# 🍅 Tomato Sentry: Disease Detection with Confidence Adjustment")
|
| 69 |
+
|
| 70 |
+
with gr.Row():
|
| 71 |
+
with gr.Column():
|
| 72 |
+
image_input = gr.Image(type="pil", label="Upload a Tomato Leaf Image")
|
| 73 |
+
scaling_method = gr.Radio(
|
| 74 |
+
["Temperature Scaling", "Min-Max Normalization"],
|
| 75 |
+
label="Confidence Scaling Method",
|
| 76 |
+
value="Temperature Scaling"
|
| 77 |
+
)
|
| 78 |
+
temperature_slider = gr.Slider(0.5, 2.0, step=0.1, label="Temperature", value=1.0)
|
| 79 |
+
min_conf_slider = gr.Slider(0, 100, step=1, label="Min Confidence", value=20)
|
| 80 |
+
max_conf_slider = gr.Slider(0, 100, step=1, label="Max Confidence", value=90)
|
| 81 |
+
detect_button = gr.Button("Detect Disease")
|
| 82 |
+
with gr.Column():
|
| 83 |
+
disease_output = gr.Textbox(label="Detected Disease & Adjusted Confidence", interactive=False)
|
| 84 |
+
raw_confidence_output = gr.Textbox(label="Raw Confidence", interactive=False)
|
| 85 |
+
|
| 86 |
+
detect_button.click(
|
| 87 |
+
detect_disease_scaled,
|
| 88 |
+
inputs=[image_input, scaling_method, temperature_slider, min_conf_slider, max_conf_slider],
|
| 89 |
+
outputs=[disease_output, raw_confidence_output]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
demo.launch()
|