# 2972535Q / groupapp.py
# lkchew's picture
# Rename app.py to groupapp.py
# 0de8ba8 verified
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as preprocess_mobilenetv2
from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_efficientnet
from huggingface_hub import hf_hub_download
import gradio as gr
import numpy as np
from PIL import Image
import requests
from io import BytesIO
# Available models: dropdown label -> checkpoint / model file name.
model_choices = {
    "Fine Tuning on VGG16 by Edwin": "best_finetune_checkpoint.weights.h5",
    "Feature Extraction on EfficientNetB0 by Seng Kuan": "full_model.weights_SK.h5",
    "Feature Extraction on MobileNetV2 by Lee Kim": "full_model_mobilenetv2.h5",
}

# Sample images offered in the "Select Test Image" dropdown (URLs or file paths).
test_image = {
    "Defective tyre image": "https://huggingface.co/spaces/lkchew/2972535Q/resolve/main/test_tire_dataset/Defective%20(454)_SK.jpg",
    "Good tyre image": "https://huggingface.co/spaces/lkchew/2972535Q/resolve/main/test_tire_dataset/good%20(160)_SK.jpg",
}

# Output labels: index 0 for scores below `threshold`, index 1 at or above it.
class_names = ['Defective Tyre', 'Good Tyre']
# Cut-off applied to the model's single sigmoid score.
threshold = 0.5

# Input resolution of the fine-tuned VGG16 model. build_model reads this
# module-level name; it matches the 128x128 resize that preprocess_image
# applies to VGG16 inputs.
image_size = (128, 128)
# Build the model (must match training architecture)
def build_model(image_size=(128, 128)):
    """Build and compile the fine-tuned VGG16 binary classifier.

    The layer stack must match the one used at training time so that the
    VGG16 weights checkpoint can be loaded into it with ``load_weights``.
    VGG16 is created with ``weights=None``, so no ImageNet weights are
    downloaded. All weights are expected to come from the checkpoint that
    is loaded afterwards.

    ``vgg16.preprocess_input`` is applied inside the graph. The model
    therefore takes raw RGB pixel values.

    Args:
        image_size: ``(height, width)`` of the input images. Defaults to
            ``(128, 128)``, the size VGG16 inputs are resized to before
            prediction.

    Returns:
        A compiled ``keras.Model`` that produces one sigmoid score per image.

    Raises:
        Exception: Any error raised while constructing the model propagates
            to the caller. It is not swallowed and returned as a string.
    """
    input_shape = tuple(image_size) + (3,)
    base_model = VGG16(weights=None, include_top=False, input_shape=input_shape)

    inputs = keras.layers.Input(shape=input_shape)
    x = preprocess_input(inputs)
    x = base_model(x, training=False)
    x = keras.layers.Flatten()(x)  # Flatten instead of GlobalAveragePooling2D, as in training
    x = keras.layers.Dropout(rate=0.5)(x)
    x = keras.layers.Dense(units=256, activation="relu")(x)
    x = keras.layers.Dropout(rate=0.5)(x)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(x)
    model = keras.models.Model(inputs=[inputs], outputs=[outputs])

    # Compile with the same settings as training.
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model
# Function to load the selected image (handle both local and online images)
def load_selected_image(image_path, timeout=10):
    """Load an image from a local path or an http(s) URL.

    Args:
        image_path: Local file path, or a URL starting with "http".
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        gr.Error: If the image cannot be downloaded, opened or decoded.
            Gradio shows the message to the user. Returning a string here
            would be fed to the ``gr.Image`` output as if it were an image.
    """
    try:
        if image_path.startswith("http"):  # online image
            response = requests.get(image_path, timeout=timeout)
            # Fail on 4xx/5xx instead of trying to decode an error page as an image.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        else:  # local image
            img = Image.open(image_path)
        # Image.open is lazy; decode now so corrupt data fails inside this try.
        img.load()
        return img
    except Exception as e:
        raise gr.Error(f"Error loading image: {e}") from e
# Models that have already been loaded, keyed by dropdown label. Each model
# is then built/deserialised once per process instead of on every click.
_loaded_models = {}


def load_model(selected_model):
    """Return the Keras model for a dropdown label, loading it on first use.

    For "Fine Tuning on VGG16 by Edwin", the architecture is rebuilt with
    ``build_model``. Its weights are read from the local file named in
    ``model_choices``.

    For any other key of ``model_choices``, the file is fetched with
    ``hf_hub_download`` from the ``lkchew/Tire_Defect_Detection`` repo. It
    is then loaded with ``tf.keras.models.load_model``.

    Args:
        selected_model: A key of ``model_choices`` (the dropdown value).

    Returns:
        The loaded ``keras.Model``. Later calls with the same label return
        the same object.

    Raises:
        ValueError: If ``selected_model`` is not a key of ``model_choices``.
        Exception: Errors from building, downloading or deserialising the
            model propagate, and nothing is cached. They are not returned as
            a string that callers would then try to ``predict`` with.
    """
    cached = _loaded_models.get(selected_model)
    if cached is not None:
        return cached

    if selected_model not in model_choices:
        raise ValueError("Invalid model selected!")
    filename = model_choices[selected_model]

    if selected_model == "Fine Tuning on VGG16 by Edwin":
        model = build_model()
        if isinstance(model, str):
            # build_model may report a construction failure as a message.
            raise RuntimeError(model)
        model.load_weights(filename)
    else:
        model_file_path = hf_hub_download(repo_id="lkchew/Tire_Defect_Detection", filename=filename)
        model = tf.keras.models.load_model(model_file_path)

    _loaded_models[selected_model] = model
    return model
# Function to preprocess the image before prediction
def preprocess_image(img, selected_model):
    """Turn a PIL image into a single-image batch for ``selected_model``.

    Args:
        img: A ``PIL.Image.Image`` in any mode. It is converted to RGB here.
        selected_model: A key of ``model_choices``.

    Returns:
        A float array of shape ``(1, H, W, 3)``. The size is 128x128 for the
        VGG16 model and 224x224 for the others. MobileNetV2 and EfficientNet
        inputs also get their preprocessing applied; VGG16 input is left as
        raw pixels.
    """
    # The models take 3 channels. Without this, RGBA PNGs or greyscale
    # images would produce 4- or 1-channel arrays.
    img = img.convert("RGB")

    size = (128, 128) if selected_model == "Fine Tuning on VGG16 by Edwin" else (224, 224)
    img = img.resize(size)
    img_array = np.expand_dims(image.img_to_array(img), axis=0)  # add batch dimension

    # Apply model-specific preprocessing
    if selected_model == "Feature Extraction on MobileNetV2 by Lee Kim":
        return preprocess_mobilenetv2(img_array)
    if selected_model == "Feature Extraction on EfficientNetB0 by Seng Kuan":
        return preprocess_efficientnet(img_array)
    # The VGG16 model from build_model already applies vgg16.preprocess_input
    # as its first step. Applying it here as well would normalise the input
    # twice, so VGG16 gets raw pixels.
    return img_array
# Prediction function
def predict_image(img, selected_model):
    """Classify ``img`` with ``selected_model`` and describe the result.

    The model outputs one sigmoid score. A score ``>= threshold`` maps to
    ``class_names[1]``; a lower score maps to ``class_names[0]``. Confidence
    is the score for the first case and ``1 - score`` for the second.

    Args:
        img: The image from the ``gr.Image`` component, or ``None``.
        selected_model: A key of ``model_choices`` (the dropdown value), or
            ``None`` when nothing is selected.

    Returns:
        A multi-line report string for the output textbox. Failures are
        reported as an ``"Error: ..."`` string rather than raised.
    """
    if img is None:
        return "Error: No image provided."
    if not selected_model:
        return "Error: No model selected."

    try:
        model = load_model(selected_model)  # Load model dynamically
        if isinstance(model, str):
            # load_model may report a loading failure as a message.
            return model
        processed_img = preprocess_image(img, selected_model)
        score = float(model.predict(processed_img)[0][0])
    except Exception as e:
        print(f"Prediction error: {e}")
        # The output component is a Textbox, so always return text.
        return f"Error: Prediction failed: {e}"

    is_good = score >= threshold
    result = class_names[1] if is_good else class_names[0]
    confidence = score if is_good else 1 - score
    return (
        f"Model: {selected_model}\n"
        f"Prediction: {result}\n"
        f"Prediction Score: {score:.4f}\n"
        f"Confidence: {confidence:.4f}"
    )
# Gradio UI with dynamic image selection
def _on_test_image_selected(img_name):
    """Load the sample image chosen in the dropdown; return None when the selection is empty."""
    return load_selected_image(test_image[img_name]) if img_name else None


with gr.Blocks() as demo:
    gr.Markdown(f"""
# Tyre Detection Model
### Classification Threshold:
- A tyre is classified as **Good** if the prediction score is **≥ {threshold}**.
- A tyre is classified as **Defective** if the prediction score is **< {threshold}**.
""")
    with gr.Row():
        test_image_dropdown = gr.Dropdown(
            choices=list(test_image.keys()),
            label="Select Test Image",
        )
        image_input = gr.Image(type="pil", label="Selected Image")
        model_dropdown = gr.Dropdown(choices=list(model_choices.keys()), label="Select Model")

    # When the test-image selection changes, show that image.
    test_image_dropdown.change(
        fn=_on_test_image_selected,
        inputs=[test_image_dropdown],
        outputs=[image_input],
    )

    predict_button = gr.Button("Predict")
    output_text = gr.Textbox(label="Prediction Result")

    # Clicking the button classifies the current image with the chosen model.
    predict_button.click(
        fn=predict_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_text],
    )

# Launch only when run as a script. Importing the module (for example in
# Gradio's reload mode, or from tests) then does not start a blocking server.
if __name__ == "__main__":
    demo.launch()