"""Gradio app that classifies maize leaf diseases with a trained Keras model."""

import gradio as gr
import numpy as np
import tensorflow as tf
from datasets import load_dataset, load_dataset_builder
from PIL import Image

DATASET_NAME = "aquib1011/maize-leaf-disease"
MODEL_PATH = "maize_disease_model.keras"
# Spatial size fed to the model (the original code's MobileNetV3 input size).
INPUT_SIZE = (224, 224)


def preprocess_inference_image(image, target_size=INPUT_SIZE):
    """Turn a single image into a model-ready batch of one.

    Args:
        image: Array-like image of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4). Pixel values are kept on their original scale.
        target_size: (height, width) to resize to. Defaults to INPUT_SIZE.

    Returns:
        A float32 tensor of shape (1, target_height, target_width, 3).

    Raises:
        ValueError: if the image is not 2-D/3-D or does not have 1, 3 or 4
            channels.
    """
    # Cast explicitly instead of relying on convert_to_tensor's dtype coercion.
    image = tf.cast(tf.convert_to_tensor(image), tf.float32)

    # Grayscale (H, W) input: add the missing channel axis.
    if image.shape.rank == 2:
        image = tf.expand_dims(image, axis=-1)
    if image.shape.rank != 3:
        raise ValueError(f"Expected a single (H, W[, C]) image, got shape {image.shape}")

    channels = image.shape[-1]
    if channels == 4:
        image = image[..., :3]  # drop the alpha channel (RGBA -> RGB)
    elif channels == 1:
        image = tf.image.grayscale_to_rgb(image)
    elif channels != 3:
        raise ValueError(f"Expected 1, 3 or 4 channels, got shape {image.shape}")

    # NOTE: no rescaling is done here, so values stay in the input's range
    # (0-255 for 8-bit images). The previous convert_image_dtype(float32 ->
    # float32) call was a no-op and never normalised to [0, 1].
    # NOTE(review): presumably the saved model rescales internally (e.g. a
    # Rescaling layer); confirm against the training pipeline.
    image = tf.image.resize(image, target_size)

    # Add the batch dimension expected by the model.
    return tf.expand_dims(image, axis=0)


def _load_label_names(dataset_name=DATASET_NAME):
    """Return the class names of the dataset's ``label`` column.

    Reads the builder metadata first, which avoids downloading the image
    files; falls back to loading the train split when the metadata does not
    declare the features.
    """
    features = load_dataset_builder(dataset_name).info.features
    if features is None or "label" not in features:
        features = load_dataset(dataset_name)["train"].features
    return features["label"].names


# Load the trained model and the class names once, at import time.
model = tf.keras.models.load_model(MODEL_PATH)
label_names = _load_label_names()


def predict_maize_disease(image):
    """Classify a maize leaf image.

    Args:
        image: PIL image from the Gradio input (or any array-like image);
            None when the user submits without uploading anything.

    Returns:
        Dict mapping each class name to the model's score for that class,
        the format Gradio's Label component expects.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
        ValueError: if the model's output size differs from the number of
            labels.
    """
    if image is None:
        raise gr.Error("Please upload an image of a maize leaf.")
    if isinstance(image, Image.Image):
        # Normalise palette / grayscale / RGBA uploads to 3-channel RGB.
        image = image.convert("RGB")

    processed_image = preprocess_inference_image(np.asarray(image))

    # Calling the model directly is cheaper than predict() for a single
    # batch and does not print a progress bar on every request.
    scores = np.asarray(model(processed_image, training=False))[0]

    if len(scores) != len(label_names):
        raise ValueError(
            f"Model returned {len(scores)} scores but there are "
            f"{len(label_names)} labels"
        )
    return {name: float(score) for name, score in zip(label_names, scores)}


# Create the Gradio interface
iface = gr.Interface(
    fn=predict_maize_disease,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Maize Leaf Disease Classifier",
    description="Upload an image of a maize leaf to get a prediction of the disease.",
)

if __name__ == "__main__":
    iface.launch()