Gopikanth123's picture
Update app.py
231e0cd verified
import gradio as gr
import numpy as np
import cv2
import tensorflow as tf
# Load the three trained Keras classifiers from their HDF5 model files,
# in the order VGG16, MobileNet, custom CNN.
vgg16_model, mobilenet_model, custom_cnn_model = (
    tf.keras.models.load_model(model_path)
    for model_path in (
        'fish_vgg16_model.h5',
        'fish_mobilenet_model.h5',
        'fish_cnn_model.h5',
    )
)
# Define the preprocessing function
def preprocess_image(image):
    """Resize an image to 224x224 and divide every value by 255.0.

    Args:
        image: Image array as delivered by the Gradio ``Image`` input
            (presumably an HxWxC uint8 RGB array -- TODO confirm).

    Returns:
        The resized array divided by 255.0, which maps 0-255 pixel values
        into the [0, 1] range.
    """
    target_size = (224, 224)
    resized = cv2.resize(image, target_size)
    return resized / 255.0
# Class names, indexed by each model's output position.
# Index 2 was previously displayed as 'Hourse Mackerel'; the typo is
# corrected here because this text is shown directly to users.
FISH_LABELS = (
    'Black Sea Sprat', 'Gilt-Head Bream', 'Horse Mackerel',
    'Red Mullet', 'Red Sea Bream', 'Sea Bass',
    'Shrimp', 'Striped Red Mullet', 'Trout',
)


def _top_label(model, batch):
    """Return the class name with the highest score for the first sample in *batch*."""
    # verbose=0 keeps Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    return FISH_LABELS[int(np.argmax(scores))]


# Define the prediction function
def predict(image):
    """Classify *image* with the custom CNN, VGG16 and MobileNet models.

    Args:
        image: Image array from the Gradio ``Image`` input; it is passed
            through ``preprocess_image`` before inference.

    Returns:
        A tuple ``(custom_cnn_label, vgg16_label, mobilenet_label)`` of class
        names, ordered to match the interface's output textboxes.
    """
    # Preprocess once and add the leading batch axis shared by all models.
    batch = np.expand_dims(preprocess_image(image), axis=0)

    # Models are queried in the original order (VGG16, MobileNet, custom CNN),
    # but results are returned in the order the UI outputs are declared.
    vgg_label = _top_label(vgg16_model, batch)
    mobilenet_label = _top_label(mobilenet_model, batch)
    custom_cnn_label = _top_label(custom_cnn_model, batch)
    return custom_cnn_label, vgg_label, mobilenet_label
# Create the Gradio interface: one image input and one textbox per model,
# listed in the same order that predict() returns its labels.
inputs = gr.components.Image()
outputs = [
    gr.components.Textbox(label=caption)
    for caption in (
        "Custom CNN Model Label",
        "VGG16 Model Label",
        "MobileNet Label",
    )
]
# NOTE(review): whether "dark" is a recognised theme name depends on the
# installed Gradio version -- confirm it actually applies a dark theme.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="Fish Classification",
    theme="dark",
)
demo.launch()