Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_vgg | |
| from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as preprocess_efficientnet | |
| from tensorflow.keras.preprocessing import image | |
| import numpy as np | |
# Spatial size every input image is resized to before inference.
# It must stay in sync with the size the checkpoints were trained with.
image_size = (224, 224)
# Function to build the VGG16 model
def build_vgg16_model():
    """Assemble and compile the VGG16-based binary classifier.

    The VGG16 backbone is created without pretrained weights and without its
    classification top. The graph applies ``preprocess_vgg`` to its input,
    runs the backbone with ``training=False``, and finishes with a
    Flatten -> Dropout -> Dense(256, relu) -> Dropout -> Dense(1, sigmoid)
    head.

    Returns:
        The compiled ``keras`` model (binary cross-entropy loss, Adam with a
        learning rate of 0.001, accuracy metric).
    """
    input_shape = image_size + (3,)
    backbone = VGG16(weights=None, include_top=False, input_shape=input_shape)

    inputs = keras.layers.Input(shape=input_shape)

    # Preprocessing is part of the graph, so the model consumes raw pixels.
    features = preprocess_vgg(inputs)
    features = backbone(features, training=False)

    # Classification head: Flatten is used instead of GlobalAveragePooling2D.
    features = keras.layers.Flatten()(features)
    features = keras.layers.Dropout(rate=0.5)(features)
    features = keras.layers.Dense(units=256, activation="relu")(features)
    features = keras.layers.Dropout(rate=0.5)(features)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(features)

    vgg_model = keras.models.Model(inputs=[inputs], outputs=[outputs])

    # Compile the model
    adam = keras.optimizers.Adam(learning_rate=0.001)
    vgg_model.compile(
        loss="binary_crossentropy",
        optimizer=adam,
        metrics=["accuracy"],
    )
    return vgg_model
# Function to build the EfficientNetB0 model
def build_efficientnet_model():
    """Assemble and compile the EfficientNetB0-based binary classifier.

    The backbone is created with ``weights=None``. The fine-tuned checkpoint
    is loaded into the whole model right after construction, so requesting
    the ImageNet weights here only added a download at start-up whose result
    was immediately replaced. This also matches how the VGG16 model is built.

    The graph applies ``preprocess_efficientnet`` to its input, runs the
    backbone with ``training=False``, and finishes with a
    GlobalAveragePooling2D -> Dropout -> Dense(256, relu) -> Dropout ->
    Dense(1, sigmoid) head.

    Returns:
        The compiled ``keras`` model (binary cross-entropy loss, Adam with a
        learning rate of 0.001, accuracy metric).
    """
    input_shape = image_size + (3,)
    base_model = EfficientNetB0(
        input_shape=input_shape, include_top=False, weights=None
    )

    inputs = keras.layers.Input(shape=input_shape)

    # Per the Keras documentation this preprocess_input is a pass-through:
    # EfficientNet rescales raw [0, 255] pixels inside the backbone itself.
    x = preprocess_efficientnet(inputs)
    x = base_model(x, training=False)

    # Classification head.
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(rate=0.5)(x)
    x = keras.layers.Dense(units=256, activation="relu")(x)
    x = keras.layers.Dropout(rate=0.5)(x)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(x)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])

    # Compile the model
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model
# Registry of the available classifiers, keyed by the name shown in the UI.
models = {}
models["VGG16"] = build_vgg16_model()
models["EfficientNetB0"] = build_efficientnet_model()

# Load the fine-tuned weights into each freshly built model.
for _name, _weights_file in (
    ("VGG16", "VGG16_best_finetune_checkpoint.weights.h5"),
    ("EfficientNetB0", "EfficientNetB0_best_finetune_checkpoint.weights.h5"),
):
    models[_name].load_weights(_weights_file)
del _name, _weights_file

print("Models and weights loaded successfully!")

# Model used until the user picks another one from the dropdown.
current_model = models["VGG16"]
# Function to update the current model based on selection
def load_selected_model(model_name):
    """Make the named model the module-wide ``current_model``.

    Args:
        model_name: Key into ``models``; an unknown key raises ``KeyError``.

    Returns:
        A short status message naming the model that was selected.
    """
    global current_model
    selected = models[model_name]
    current_model = selected
    status = f"Loaded {model_name} model."
    return status
# Preprocessing function
def preprocess_img(img, model_name):
    """Turn an uploaded PIL image into a one-image batch of raw pixels.

    Args:
        img: PIL image to classify.
        model_name: Name of the selected model. Kept so existing callers
            keep working; it no longer influences the result because each
            model normalizes its own input (see below).

    Returns:
        Array holding the resized image with a leading batch axis of size 1.
    """
    # Force three channels so the array matches the models' input shape even
    # if the caller passes a grayscale or RGBA image.
    resized = img.convert("RGB").resize(image_size)  # Resize to match training
    img_array = image.img_to_array(resized)

    # Both build_vgg16_model and build_efficientnet_model already call their
    # preprocess_input inside the model graph. Calling it here as well made
    # VGG16 inputs go through channel flipping and mean subtraction twice.
    # NOTE(review): this assumes training also fed raw pixels to the model;
    # confirm against the training pipeline.
    return np.expand_dims(img_array, axis=0)  # Add batch dimension
# Prediction function
def predict(img, model_name):
    """Classify an image with the model named in ``model_name``.

    Args:
        img: PIL image from the upload widget, or ``None`` when nothing has
            been uploaded yet.
        model_name: Key into ``models`` selecting which classifier to run.

    Returns:
        A multi-line string with the predicted class and both confidences,
        or a short prompt when no image was provided.
    """
    # Without an upload the image is None; answer with a prompt instead of
    # failing inside preprocessing.
    if img is None:
        return "Please upload an image before classifying."

    # Look the model up from the request itself rather than the shared
    # ``current_model`` global. The global is process-wide, so another
    # session (or a click racing the dropdown change event) could otherwise
    # pair this request with a different model than the one selected.
    model = models[model_name]

    img_array = preprocess_img(img, model_name)
    predictions = model.predict(img_array)
    confidence = float(predictions[0][0])  # Convert to float

    # Class labels: the sigmoid output is reported as the "Good" confidence.
    class_labels = {0: "Defective", 1: "Good"}
    predicted_class = 1 if confidence > 0.5 else 0

    result_text = (
        f"Predicted Class: {class_labels[predicted_class]}\n"
        f"Confidence (Good): {confidence:.8f}\n"
        f"Confidence (Defective): {1 - confidence:.8f}"
    )
    return result_text
# Function to clear input and output
def clear():
    """Return the values that reset the image input and the output text.

    Returns:
        A ``(None, "")`` tuple: ``None`` for the image, ``""`` for the text.
    """
    cleared_image = None
    cleared_text = ""
    return cleared_image, cleared_text
# Gradio Interface with Model Selection
with gr.Blocks() as interface:
    # Introductory description shown at the top of the page.
    gr.Markdown(
        "## Fine-tuned Defect Tyre Classification\n\n"
        "This Gradio-based application allows users to classify tyre images as **'Good'** or **'Defective'** using a "
        "fine-tuned deep learning model. Users can select between two models (**VGG16** and **EfficientNetB0**) "
        "via a dropdown menu. The selected model is dynamically loaded and applied to the uploaded image, "
        "with predictions and confidence scores displayed as output.\n\n"
        "Upload a tyre image and select a model (**VGG16** or **EfficientNetB0**) to classify defects."
    )

    # Model picker and a read-only box reporting the current selection.
    model_dropdown = gr.Dropdown(
        label="Select Model",
        choices=list(models),
        value="VGG16",
    )
    model_status = gr.Textbox(label="Model Status", interactive=False)

    # Image input next to the textual classification result.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_text = gr.Textbox(label="Classification Output", interactive=False)

    # Action buttons.
    with gr.Row():
        classify_button = gr.Button("Classify")
        clear_button = gr.Button("Clear")

    # Changing the dropdown switches the model and updates the status box.
    model_dropdown.change(
        fn=load_selected_model,
        inputs=model_dropdown,
        outputs=model_status,
    )

    # "Classify" runs the prediction; "Clear" resets image and output text.
    classify_button.click(
        fn=predict,
        inputs=[input_image, model_dropdown],
        outputs=output_text,
    )
    clear_button.click(
        fn=clear,
        inputs=[],
        outputs=[input_image, output_text],
    )

# Run the app
if __name__ == "__main__":
    interface.launch()