Spaces:
Build error
Build error
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as preprocess_mobilenetv2 | |
| from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_efficientnet | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
# Dropdown label -> weights file fetched from the "lkchew/Tire_Defect_Detection"
# Hugging Face repo (see load_model). Insertion order is the dropdown order.
model_choices = dict(
    EfficientNetB0="full_model.weights_efficientnetb0.h5",
    MobileNetV2="full_model (1).h5",
)

# Output labels: index 0 is used when the score is below `threshold`,
# index 1 when it is at or above it (see predict_image).
class_names = ["Defective Tyre", "Good Tyre"]

# Decision boundary applied to the model's single output score.
threshold = 0.5
# Keras models already loaded in this process, keyed by model_choices key.
# The key space is bounded by model_choices, so this cache cannot grow unbounded.
_loaded_models = {}


def load_model(selected_model):
    """Return the Keras model registered under ``selected_model``.

    On first use the weights file named in ``model_choices`` is downloaded from
    the ``lkchew/Tire_Defect_Detection`` Hugging Face repo and deserialized with
    ``tf.keras.models.load_model``. The resulting model is kept in
    ``_loaded_models``, so later calls reuse it instead of reloading it on
    every prediction.

    Args:
        selected_model: A key of ``model_choices``.

    Returns:
        The loaded ``tf.keras`` model.

    Raises:
        ValueError: If ``selected_model`` is not a key of ``model_choices``.
    """
    if selected_model not in model_choices:
        raise ValueError("Invalid model selected!")

    if selected_model not in _loaded_models:
        filename = model_choices[selected_model]  # Get filename dynamically
        model_file_path = hf_hub_download(
            repo_id="lkchew/Tire_Defect_Detection", filename=filename
        )
        _loaded_models[selected_model] = tf.keras.models.load_model(model_file_path)

    return _loaded_models[selected_model]
# Preprocessing the image
def preprocess_image(img, selected_model):
    """Convert a PIL image into a preprocessed single-image batch for a model.

    The image is forced to RGB and resized to 224x224. It is turned into an
    array, given a leading batch axis, and passed through the Keras
    ``preprocess_input`` that matches ``selected_model``. With Keras' default
    channels-last data format the batch is (1, 224, 224, 3).

    Args:
        img: A PIL ``Image`` (e.g. from ``gr.Image(type="pil")``).
        selected_model: ``"MobileNetV2"`` or ``"EfficientNetB0"``.

    Returns:
        The array returned by the model-specific ``preprocess_input``.

    Raises:
        ValueError: If ``selected_model`` has no preprocessing function. Previously
            this silently returned ``None`` and failed later in ``model.predict``.
    """
    preprocessors = {
        "MobileNetV2": preprocess_mobilenetv2,
        "EfficientNetB0": preprocess_efficientnet,
    }
    if selected_model not in preprocessors:
        raise ValueError(f"No preprocessing defined for model: {selected_model!r}")

    # RGBA / grayscale / palette images would otherwise yield 4- or 1-channel
    # arrays; the models are fed 3-channel input.
    img = img.convert("RGB").resize((224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return preprocessors[selected_model](img_array)
# Prediction function
def predict_image(img, selected_model):
    """Classify a tyre image with the selected model and return a text summary.

    Args:
        img: PIL image from the Gradio image input, or ``None`` if nothing
            was uploaded.
        selected_model: A key of ``model_choices`` from the dropdown, or
            ``None`` / empty if nothing was chosen.

    Returns:
        Multi-line text with the model name, predicted class, raw score and
        confidence. Confidence is the score for "Good Tyre" and
        ``1 - score`` for "Defective Tyre".

    Raises:
        gr.Error: If no image was uploaded or no model was selected, so Gradio
            shows a readable message instead of a stack trace.
        ValueError: If ``selected_model`` is not a known model (from
            ``load_model`` / ``preprocess_image``).
    """
    if img is None:
        raise gr.Error("Please upload a tyre image before predicting.")
    if not selected_model:
        raise gr.Error("Please select a model before predicting.")

    model = load_model(selected_model)
    processed_img = preprocess_image(img, selected_model)
    # verbose=0 stops Keras printing a progress bar on every request.
    prediction = model.predict(processed_img, verbose=0)

    # The first output of the first (only) batch item is treated as the score.
    # It is presumably a sigmoid probability in [0, 1]; confirm against training.
    score = float(prediction[0][0])
    is_good = score >= threshold
    result = class_names[1] if is_good else class_names[0]
    confidence = score if is_good else 1 - score

    return (
        f"Model: {selected_model},\n"
        f"Prediction: {result}, Prediction Score: {score:.4f},\n"
        f"Confidence: {confidence:.4f}"
    )
# Gradio UI with model selection
with gr.Blocks() as demo:
    # The threshold is interpolated so the description cannot drift from the
    # value predict_image actually uses (renders as 0.5 today).
    gr.Markdown(f"""
# Tyre Defect Detection Model by Chew Lee Kim
### Binary Classification Criteria:
- A tyre is classified as **Good** if the prediction score is **≥ {threshold}**.
- A tyre is classified as **Defective** if the prediction score is **< {threshold}**.
""")
    image_input = gr.Image(type="pil", label="Selected Image")
    model_dropdown = gr.Dropdown(
        choices=list(model_choices.keys()),
        # Preselect the first model so clicking Predict never sends None.
        value=next(iter(model_choices)),
        label="Select Model",
    )
    predict_button = gr.Button("Predict")
    output_text = gr.Textbox(label="Prediction Result")
    predict_button.click(
        fn=predict_image,
        inputs=[image_input, model_dropdown],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()