# Source: Hugging Face repo lkchew/2972535Q, file app.py
# Last commit: 06b21e6 ("Update app.py") by lkchew
import functools

import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as preprocess_mobilenetv2
from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_efficientnet
from huggingface_hub import hf_hub_download
import gradio as gr
import numpy as np
from PIL import Image
# Dropdown label -> model filename fetched from the Hub by load_model.
# Insertion order is the order shown in the UI dropdown.
model_choices = dict(
    EfficientNetB0="full_model.weights_efficientnetb0.h5",
    MobileNetV2="full_model (1).h5",
)

# Sample images, currently unused (kept for reference):
# test_image = {
#     "Defective tyre image 1": "lkchew/2972535Q/test_tire_dataset/Defective%20(1)_resized.jpg",
#     "Defective tyre image 2": "lkchew/2972535Q/test_tire_dataset/Defective%20(1)_resized.jpg",
#     "Good tyre image 1": "lkchew/2972535Q/test_tire_dataset/good%20(40)_resized.jpg",
#     "Good tyre image 2": "lkchew/2972535Q/test_tire_dataset/good%20(40)_resized.jpg"
# }

# Output labels: index 0 is reported when the score is below `threshold`,
# index 1 when it is at or above it.
class_names = ["Defective Tyre", "Good Tyre"]

# Decision boundary applied to the model's prediction score.
threshold = 0.5
# def load_selected_image(image_path):
# """Loads the selected image from the dropdown."""
# return Image.open(image_path) if image_path else None
@functools.lru_cache(maxsize=None)
def load_model(selected_model):
    """Download (if needed) and load the Keras model for ``selected_model``.

    The loaded model is memoised per model name, so the ``.h5`` file is
    resolved on the Hub and deserialised only once per process rather than
    on every prediction request. Calls that raise are not cached, and only
    valid keys of ``model_choices`` can succeed, so the cache stays small.

    Args:
        selected_model: A key of ``model_choices`` (e.g. "EfficientNetB0").

    Returns:
        The model returned by ``tf.keras.models.load_model``.

    Raises:
        ValueError: If ``selected_model`` is not a key of ``model_choices``.
    """
    if selected_model not in model_choices:
        raise ValueError("Invalid model selected!")
    filename = model_choices[selected_model]
    model_file_path = hf_hub_download(
        repo_id="lkchew/Tire_Defect_Detection", filename=filename
    )
    return tf.keras.models.load_model(model_file_path)
def preprocess_image(img, selected_model):
    """Convert a PIL image into a preprocessed batch of one for ``selected_model``.

    Args:
        img: Input PIL image of any size or mode.
        selected_model: "MobileNetV2" or "EfficientNetB0"; selects which
            Keras ``preprocess_input`` function is applied.

    Returns:
        The preprocessed array with a leading batch dimension of 1
        (shape (1, 224, 224, 3) under Keras's default channels_last format).

    Raises:
        ValueError: If ``selected_model`` has no matching preprocessing.
    """
    preprocessors = {
        "MobileNetV2": preprocess_mobilenetv2,
        "EfficientNetB0": preprocess_efficientnet,
    }
    if selected_model not in preprocessors:
        # Fail loudly here instead of returning None, which would only
        # surface later as a confusing error inside model.predict.
        raise ValueError(f"No preprocessing defined for model {selected_model!r}")

    # Force 3 channels before resizing: RGBA (PNG with alpha), greyscale or
    # palette images would otherwise give 4- or 1-channel arrays.
    img = img.convert("RGB").resize((224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
    return preprocessors[selected_model](img_array)
def predict_image(img, selected_model):
    """Classify a tyre image with the chosen model and format the result.

    Args:
        img: PIL image from the Gradio image input, or None when nothing
            has been uploaded.
        selected_model: Model name from the dropdown, or a falsy value when
            nothing has been selected.

    Returns:
        A multi-line string with the model name, predicted class, raw
        prediction score and confidence.

    Raises:
        gr.Error: If no image or no model was provided. Gradio shows the
            message to the user instead of a generic error.
    """
    if img is None:
        raise gr.Error("Please upload a tyre image first.")
    if not selected_model:
        raise gr.Error("Please select a model first.")

    model = load_model(selected_model)
    processed_img = preprocess_image(img, selected_model)
    # verbose=0 suppresses Keras's per-call progress bar in the server log.
    prediction = model.predict(processed_img, verbose=0)

    # The first output is treated as the score for class_names[1]; it is
    # presumably a sigmoid probability in [0, 1], given the 1 - score below.
    score = float(prediction[0][0])
    if score >= threshold:
        result, confidence = class_names[1], score
    else:
        result, confidence = class_names[0], 1 - score

    return (
        f"Model: {selected_model},\n"
        f"Prediction: {result}, Prediction Score: {score:.4f},\n"
        f"Confidence: {confidence:.4f}"
    )
#Gradio UI with model selection
# gr.Markdown("Tyre Detection Model using Feature Extraction")
# gr_interface = gr.Interface(
# fn=predict_image,
# inputs=[
# gr.Image(type="pil"), # Dropdown for selecting images
# gr.Dropdown(choices=list(model_choices.keys()), label="Select Model") # Model selection
# ],
# outputs="text"
# )
# gr_interface.launch(share=True)
# Description shown above the inputs. `threshold` is interpolated so the
# displayed rule cannot drift from the value used for classification.
APP_DESCRIPTION = f"""
# Tyre Defect Detection Model by Chew Lee Kim
### Binary Classification Criteria:
- A tyre is classified as **Good** if the prediction score is **≥ {threshold}**.
- A tyre is classified as **Defective** if the prediction score is **< {threshold}**.
"""

with gr.Blocks() as demo:
    gr.Markdown(APP_DESCRIPTION)
    image_input = gr.Image(type="pil", label="Selected Image")
    # Preselect the first model so "Predict" works without touching the
    # dropdown; an empty dropdown would pass no model name to predict_image.
    model_dropdown = gr.Dropdown(
        choices=list(model_choices),
        value=next(iter(model_choices)),
        label="Select Model",
    )
    predict_button = gr.Button("Predict")
    output_text = gr.Textbox(label="Prediction Result")
    predict_button.click(
        fn=predict_image,
        inputs=[image_input, model_dropdown],
        outputs=output_text,
    )

demo.launch()