Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import tensorflow as tf | |
# Load the three trained classifiers once at import time so every request reuses them.
# compile=False: the app only ever calls predict(), so the optimizer/loss/metric
# state saved in the .h5 files is not needed and is skipped instead of deserialized.
vgg16_model = tf.keras.models.load_model('fish_vgg16_model.h5', compile=False)
mobilenet_model = tf.keras.models.load_model('fish_mobilenet_model.h5', compile=False)
custom_cnn_model = tf.keras.models.load_model('fish_cnn_model.h5', compile=False)
# Define the preprocessing function
def preprocess_image(image, size=(224, 224)):
    """Resize an image and scale its pixel values into [0, 1] for the classifiers.

    Args:
        image: Image as an array-like of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4). Single-channel input is replicated to three channels and
            a fourth (alpha) channel is dropped, so the result always has three
            channels.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to
            224x224, the size every model in this app is fed.

    Returns:
        float32 array of shape ``(height, width, 3)`` with values divided by 255.

    Raises:
        ValueError: If ``image`` is None or does not have one of the shapes above.
    """
    if image is None:
        raise ValueError("No image provided.")
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.ndim != 3 or image.shape[-1] not in (1, 3, 4):
        raise ValueError(
            f"Unsupported image shape {image.shape}; expected (H, W) or (H, W, 1|3|4)."
        )
    # Normalise channels before resizing: cv2.resize squeezes a singleton
    # channel axis, which would leave a 2-D array the models cannot accept.
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]
    # Slicing off alpha yields a strided view; give OpenCV a contiguous buffer.
    image = cv2.resize(np.ascontiguousarray(image), size)
    return image.astype(np.float32) / 255.0
# Define the prediction function
# Display names for the classifier output units: index i is the label for
# output i. All three models are decoded with this single mapping.
# "Horse Mackerel" fixes the "Hourse Mackerel" typo shown to users.
FISH_CLASS_NAMES = (
    "Black Sea Sprat",
    "Gilt-Head Bream",
    "Horse Mackerel",
    "Red Mullet",
    "Red Sea Bream",
    "Sea Bass",
    "Shrimp",
    "Striped Red Mullet",
    "Trout",
)


def predict(image):
    """Classify a fish image with all three models.

    Args:
        image: Image array from the Gradio Image input, or None if it is empty.

    Returns:
        Tuple ``(custom_cnn_label, vgg16_label, mobilenet_label)`` of class
        names, in the same order as the interface's output textboxes. Returns
        three empty strings when ``image`` is None, instead of failing.
    """
    if image is None:
        return "", "", ""

    # Build the single-image batch once and share it across all models.
    batch = np.expand_dims(preprocess_image(image), axis=0)

    # verbose=0 stops Keras printing a progress bar for every request.
    vgg_pred = vgg16_model.predict(batch, verbose=0)[0]
    mobilenet_pred = mobilenet_model.predict(batch, verbose=0)[0]
    custom_cnn_pred = custom_cnn_model.predict(batch, verbose=0)[0]

    return (
        FISH_CLASS_NAMES[int(np.argmax(custom_cnn_pred))],
        FISH_CLASS_NAMES[int(np.argmax(vgg_pred))],
        FISH_CLASS_NAMES[int(np.argmax(mobilenet_pred))],
    )
# Create the Gradio interface: one image upload in, one textbox per model out.
# The textbox order must match the tuple order returned by predict():
# custom CNN, then VGG16, then MobileNet.
inputs = gr.Image()
outputs = [
    gr.Textbox(label=caption)
    for caption in (
        "Custom CNN Model Label",
        "VGG16 Model Label",
        "MobileNet Label",
    )
]
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Fish Classification",
    theme="dark",
).launch()