Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| from tensorflow.keras.preprocessing.image import load_img, img_to_array | |
| from tensorflow.keras.applications.imagenet_utils import preprocess_input | |
| import os | |
# File name of the saved Keras model; resolved relative to the current
# working directory, exactly as passed to load_model.
MODEL_PATH = "final_trashnet_transfer_learning_model.keras"

# Loaded once at import time; classify_trash reuses this module-level object
# for every prediction.
model = tf.keras.models.load_model(MODEL_PATH)
# Model output index -> human-readable category label. The position of each
# label must line up with the model's output units (presumably the label order
# used at training time — confirm against the training script).
class_mapping = dict(enumerate(("Compostable", "Recyclables", "Trash")))
def preprocess_image(image, target_size=(128, 128)):
    """Turn a PIL image into a model-ready batch containing one image.

    Args:
        image: A PIL ``Image`` in any mode. It is converted to RGB first so
            the resulting array always has exactly three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default 128x128 is the size this app feeds to its model.

    Returns:
        A NumPy array of shape ``(1, height, width, 3)`` (with Keras' default
        channels_last data format) after ``preprocess_input``.
    """
    # RGBA / grayscale / palette inputs would otherwise produce 4 or 1
    # channels. preprocess_input's default mode indexes channels 0-2, and the
    # model is assumed to expect 3-channel input.
    rgb_image = image.convert("RGB")
    resized = rgb_image.resize(target_size)

    img_array = img_to_array(resized)
    # NOTE(review): imagenet_utils.preprocess_input defaults to mode="caffe"
    # (RGB->BGR plus ImageNet mean subtraction). Confirm this matches the
    # normalization used when the model was trained.
    img_array = preprocess_input(img_array)

    # Add the leading batch dimension: (H, W, 3) -> (1, H, W, 3).
    return np.expand_dims(img_array, axis=0)
def classify_trash(image):
    """Classify an uploaded image into one of the labels in ``class_mapping``.

    Args:
        image: PIL image coming from ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(category_text, confidence_text)`` tuple, one string per output
        textbox.
    """
    if image is None:
        # Without this guard preprocess_image would fail with
        # AttributeError on ``None.resize``.
        return "Please upload an image.", ""

    processed_image = preprocess_image(image)
    # The input is a batch of one; take its single row of scores.
    # verbose=0 keeps Keras from printing a progress bar on every request.
    scores = model.predict(processed_image, verbose=0)[0]

    class_index = int(np.argmax(scores))
    # Assumes the model's last layer is a softmax, so the score is a
    # probability in [0, 1] — TODO confirm.
    confidence = float(scores[class_index])
    predicted_class = class_mapping[class_index]
    return (
        f"Predicted Category: {predicted_class}",
        f"Confidence: {confidence * 100:.2f}%",
    )
def get_example_images(
    base_dir="examples",
    categories=("Compostable", "Recyclables", "Trash"),
    extensions=(".jpg", ".jpeg", ".png"),
):
    """Collect example image paths for the Gradio ``examples`` gallery.

    Scans ``<base_dir>/<category>/`` for each category and keeps the files
    whose names end with one of ``extensions``, compared case-insensitively.

    Args:
        base_dir: Root folder holding one sub-folder per category.
        categories: Sub-folder names to scan, in output order.
        extensions: Accepted file-name suffixes (a single string is allowed).

    Returns:
        A list of paths built with ``os.path.join``. Paths are grouped by
        category in ``categories`` order and sorted by file name within each
        group. Missing category folders are skipped, so the list may be empty.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith needs a tuple; lower-case once so ".JPG" etc. also match.
    suffixes = tuple(ext.lower() for ext in extensions)

    example_images = []
    for category in categories:
        folder_path = os.path.join(base_dir, category)
        # isdir rather than exists: a regular file with a category's name
        # would make os.listdir raise NotADirectoryError.
        if not os.path.isdir(folder_path):
            continue
        # os.listdir order is arbitrary; sort so the gallery order is stable.
        example_images.extend(
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if name.lower().endswith(suffixes)
        )
    return example_images
# Paths of example images from every category folder; may be an empty list.
example_images = get_example_images()

# Gradio UI: one PIL image in, two text fields (category, confidence) out.
interface = gr.Interface(
    fn=classify_trash,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Textbox(label="Confidence"),
    ],
    examples=example_images,
    title="Trash Classifier",
    description=(
        "Upload an image of trash, and the model will classify it into "
        "'Compostable', 'Recyclables' and 'Trash' based on its category."
    ),
)

if __name__ == "__main__":
    # Serve the app only when executed directly, not when imported.
    interface.launch()