Spaces:
Build error
Build error
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as preprocess_mobilenetv2 | |
| from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_efficientnet | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
# Models offered in the "Select Model" dropdown, mapped to the file each one loads.
# Insertion order is the order the choices appear in the UI.
model_choices = dict([
    ("Fine Tuning on VGG16 by Edwin", "best_finetune_checkpoint.weights.h5"),
    ("Feature Extraction on EfficientNetB0 by Seng Kuan", "full_model.weights_SK.h5"),
    ("Feature Extraction on MobileNetV2 by Lee Kim", "full_model_mobilenetv2.h5"),
])

# Sample images for the "Select Test Image" dropdown (http(s) URLs or local paths).
_TEST_IMAGE_BASE_URL = (
    "https://huggingface.co/spaces/lkchew/2972535Q/resolve/main/test_tire_dataset/"
)
test_image = {
    "Defective tyre image": _TEST_IMAGE_BASE_URL + "Defective%20(454)_SK.jpg",
    "Good tyre image": _TEST_IMAGE_BASE_URL + "good%20(160)_SK.jpg",
}

# Output labels: index 0 when the score is below `threshold`, index 1 otherwise.
class_names = ["Defective Tyre", "Good Tyre"]
# Prediction-score cut-off between the two classes.
threshold = 0.5
# Build the model (must match training architecture)
def build_model(image_size=(128, 128)):
    """Rebuild the fine-tuned VGG16 classifier so its saved weights can be loaded.

    The layer stack must match the training architecture exactly, otherwise
    ``load_weights`` cannot map the checkpoint onto it. VGG16's
    ``preprocess_input`` is part of the graph, so the model is fed raw RGB
    pixel values.

    Args:
        image_size: (height, width) of the input images. Defaults to (128, 128),
            the size ``preprocess_image`` resizes VGG16 inputs to. (The original
            code read an ``image_size`` global that was never defined.)

    Returns:
        A compiled ``keras.Model`` with a single sigmoid output unit.

    Raises:
        Any exception from model construction is propagated so the caller can
        report the real cause (previously a detail-free string was returned).
    """
    input_shape = tuple(image_size) + (3,)
    # weights=None: every weight, backbone included, comes from the checkpoint
    # that the caller loads afterwards.
    base_model = VGG16(weights=None, include_top=False, input_shape=input_shape)
    inputs = keras.layers.Input(shape=input_shape)
    x = preprocess_input(inputs)
    x = base_model(x, training=False)
    x = keras.layers.Flatten()(x)  # Flatten (not GlobalAveragePooling2D), as in training
    x = keras.layers.Dropout(rate=0.5)(x)
    x = keras.layers.Dense(units=256, activation="relu")(x)
    x = keras.layers.Dropout(rate=0.5)(x)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(x)
    model = keras.models.Model(inputs=[inputs], outputs=[outputs])
    # Compile with the training settings (same as training).
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model
# Function to load the selected image (handle both local and online images)
def load_selected_image(image_path, timeout=10):
    """Load a sample image from an http(s) URL or a local file path.

    Args:
        image_path: URL (any string starting with ``"http"``) or local file path.
        timeout: Seconds to wait on the HTTP server before giving up.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if loading failed.
        ``None`` clears the Gradio image component. The previous code returned
        an error string, which the Image output would try to treat as a file
        path.
    """
    try:
        if image_path.startswith("http"):  # online image
            # The timeout keeps a stalled download from hanging the UI callback.
            response = requests.get(image_path, timeout=timeout)
            # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        else:  # local image
            img = Image.open(image_path)
        # Decode now so a corrupt file is reported here rather than at predict time.
        img.load()
        return img
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
# Function to load the selected model
# Loaded models keyed by display name: each model is built/downloaded once per
# process and reused, instead of being reloaded on every "Predict" click.
_MODEL_CACHE = {}


def _load_model_uncached(selected_model):
    """Build or download the model named ``selected_model``; raise on any failure."""
    if selected_model == "Fine Tuning on VGG16 by Edwin":
        # This checkpoint holds weights only, so rebuild the architecture first.
        model = build_model()
        if isinstance(model, str):
            # build_model may signal failure with a message string instead of raising.
            raise RuntimeError(model)
        # Weights file is read from the working directory.
        model.load_weights(model_choices[selected_model])
        return model
    if selected_model in model_choices:
        filename = model_choices[selected_model]
        model_file_path = hf_hub_download(repo_id="lkchew/Tire_Defect_Detection", filename=filename)
        return tf.keras.models.load_model(model_file_path)
    raise ValueError("Invalid model selected!")


def load_model(selected_model):
    """Return the Keras model for a display name from ``model_choices``.

    Successful loads are cached, so repeated predictions reuse the same model.

    Args:
        selected_model: A key of ``model_choices``.

    Returns:
        The loaded ``keras.Model`` on success. On failure, a string starting
        with ``"Error loading model: "``. Callers must check for a ``str``
        result.
    """
    model = _MODEL_CACHE.get(selected_model)
    if model is not None:
        return model
    try:
        model = _load_model_uncached(selected_model)
    except Exception as e:
        return f"Error loading model: {e}"
    # Only successes are cached, so a transient failure can be retried.
    _MODEL_CACHE[selected_model] = model
    return model
# Function to preprocess the image before prediction
def preprocess_image(img, selected_model):
    """Turn a PIL image into a single-image float batch for ``selected_model``.

    Args:
        img: PIL image from the Gradio component (any mode).
        selected_model: A key of ``model_choices``.

    Returns:
        A numpy array shaped (1, H, W, 3): 128x128 for the VGG16 model and
        224x224 for the others.
    """
    # Uploads may be RGBA or grayscale; convert so the array always has 3 channels.
    img = img.convert("RGB")
    if selected_model == "Fine Tuning on VGG16 by Edwin":
        img = img.resize((128, 128))  # VGG16 checkpoint input size
    else:
        img = img.resize((224, 224))  # Resize to match model input size
    img_array = image.img_to_array(img)  # Convert to array
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    # Apply model-specific preprocessing
    if selected_model == "Feature Extraction on MobileNetV2 by Lee Kim":
        return preprocess_mobilenetv2(img_array)
    if selected_model == "Feature Extraction on EfficientNetB0 by Seng Kuan":
        return preprocess_efficientnet(img_array)
    # VGG16 (and any other model): return raw 0-255 pixels. build_model already
    # applies VGG16 preprocess_input inside the network, so applying it here as
    # well would preprocess twice (undoing the channel flip and subtracting the
    # ImageNet means a second time).
    return img_array
# Prediction function
def predict_image(img, selected_model):
    """Classify ``img`` with ``selected_model`` and return a readable report.

    Always returns a ``str`` for the Gradio textbox: either the prediction
    summary or an error message (previously an exception was returned wrapped
    in a set literal).

    Args:
        img: PIL image from the UI, or ``None`` if nothing was provided.
        selected_model: A key of ``model_choices``.
    """
    if img is None:
        return "Error: No image provided."
    # Top-level UI handler: report any failure in the textbox instead of crashing.
    try:
        model = load_model(selected_model)  # Load model dynamically
        if isinstance(model, str):
            # load_model reports failures as a message string rather than raising.
            return model
        processed_img = preprocess_image(img, selected_model)  # Preprocess input
        prediction = model.predict(processed_img)  # Run inference
        score = float(prediction[0][0])
    except Exception as e:
        print(f"Prediction error: {e}")
        return f"Error: Prediction failed: {e}"
    # score >= threshold -> class_names[1], otherwise class_names[0]. The
    # confidence is the score itself for class 1 and 1 - score for class 0.
    if score >= threshold:
        result, confidence = class_names[1], score
    else:
        result, confidence = class_names[0], 1 - score
    return (
        f"Model: {selected_model}\n"
        f"Prediction: {result}\n"
        f"Prediction Score: {score:.4f}\n"
        f"Confidence: {confidence:.4f}"
    )
# Gradio UI with dynamic image selection
def _show_test_image(img_name):
    """Dropdown callback: load the chosen sample image, or clear it if nothing is chosen."""
    if not img_name:
        return None
    return load_selected_image(test_image[img_name])


# NOTE(review): the source's indentation was lost; which components sit inside
# gr.Row below is a reconstruction — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("""
# Tyre Detection Model
### Classification Threshold:
- A tyre is classified as **Good** if the prediction score is **≥ 0.5**.
- A tyre is classified as **Defective** if the prediction score is **< 0.5**.
""")
    with gr.Row():
        test_image_dropdown = gr.Dropdown(
            choices=[name for name in test_image],
            label="Select Test Image",
        )
        image_input = gr.Image(type="pil", label="Selected Image")
        model_dropdown = gr.Dropdown(
            choices=[name for name in model_choices],
            label="Select Model",
        )

    # Picking a sample image replaces whatever is shown in the image component.
    test_image_dropdown.change(
        fn=_show_test_image,
        inputs=[test_image_dropdown],
        outputs=[image_input],
    )

    predict_button = gr.Button("Predict")
    output_text = gr.Textbox(label="Prediction Result")

    # Clicking "Predict" classifies the current image with the chosen model.
    predict_button.click(
        fn=predict_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_text],
    )

demo.launch()