Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import gdown | |
| from PIL import Image | |
| import pillow_avif | |
# Height, width and channel count the classifier consumes; predict_class
# resizes every image to this height/width before inference.
input_shape = (32, 32, 3)
# NOTE(review): not referenced by the visible code; presumably the size an
# internal backbone upsamples to — confirm or remove.
resized_shape = (224, 224, 3)

# Class index -> display name. The order must match the model's output
# vector, since classify_image looks labels up by output position.
labels = dict(
    enumerate(
        (
            "plane", "car", "bird", "cat", "deer",
            "dog", "frog", "horse", "ship", "truck",
        )
    )
)
# One output class per entry in ``labels``.
num_classes = len(labels)
def download_model(
    url="https://drive.google.com/uc?id=12700bE-pomYKoVQ214VrpBoJ7akXcTpL",
    output="modelV2Lmixed.keras",
):
    """Download the Keras model file from Google Drive and return its path.

    Args:
        url: Google Drive download URL of the model file.
        output: Local filename to write the model to.

    Returns:
        The local path ``output``, ready to pass to ``load_model``.

    Raises:
        RuntimeError: if ``gdown.download`` reports failure by returning None.
            (Newer gdown versions raise their own exception instead, which
            propagates unchanged.)
    """
    # Older gdown releases print an error and return None on failure (e.g.
    # access denied, proxy error) instead of raising. Without this check the
    # app would carry on and fail later with an unrelated load_model error.
    if gdown.download(url, output, quiet=False) is None:
        raise RuntimeError(f"Failed to download model from {url!r} to {output!r}")
    return output
# Local path of the freshly downloaded .keras checkpoint.
model_file = download_model()
# Deserialize once at import time; every request reuses this in-memory model.
model = tf.keras.models.load_model(filepath=model_file)
| # Perform image classification for single class output | |
| # def predict_class(image): | |
| # img = tf.cast(image, tf.float32) | |
| # img = tf.image.resize(img, [input_shape[0], input_shape[1]]) | |
| # img = tf.expand_dims(img, axis=0) | |
| # prediction = model.predict(img) | |
| # class_index = tf.argmax(prediction[0]).numpy() | |
| # predicted_class = labels[class_index] | |
| # return predicted_class | |
| # Perform image classification for multy class output | |
| def predict_class(image): | |
| img = tf.cast(image, tf.float32) | |
| img = tf.image.resize(img, [input_shape[0], input_shape[1]]) | |
| img = tf.expand_dims(img, axis=0) | |
| prediction = model.predict(img) | |
| return prediction[0] | |
| # UI Design for single class output | |
| # def classify_image(image): | |
| # predicted_class = predict_class(image) | |
| # output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>" | |
| # return output | |
# UI callback for multi-class output.

# Minimum top-class score required before CIFAR-10 labels are reported;
# below it the image is treated as belonging to no CIFAR-10 class.
# NOTE(review): assumes the model ends in a softmax so scores are
# probabilities in [0, 1] — confirm.
CONFIDENCE_THRESHOLD = 0.98


def classify_image(image):
    """Gradio callback: map an uploaded image to per-label confidences.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.

    Returns:
        - None if ``image`` is None, so Gradio shows an empty Label instead
          of an error.
        - A ``{label: score}`` dict over all classes when the top score
          reaches ``CONFIDENCE_THRESHOLD``.
        - Otherwise ``{"NO_CIFAR10_CLASS": 1}``, so out-of-distribution images
          are not forced into a CIFAR-10 class.
    """
    if image is None:
        # Gradio's Image input yields None when the user submits with no upload.
        return None
    results = predict_class(image)
    print("results is ...", results)
    output = {labels.get(i): float(score) for i, score in enumerate(results)}
    print("output is ...", output)
    if max(output.values()) >= CONFIDENCE_THRESHOLD:
        return output
    return {"NO_CIFAR10_CLASS": 1}
# --- Gradio interface ---------------------------------------------------------
title = "<h1 style='text-align: center;'>Image Classifier</h1>"
description = "Upload an image and get the predicted class."

inputs = gr.components.Image(type="pil", label="Upload an image")
# Multi-class mode: show the four highest-scoring labels.
# For single-class output, uncomment the HTML output instead:
# outputs = gr.outputs.HTML()
outputs = gr.components.Label(num_top_classes=4)

# Example image paths (expected next to this script). "02_house.jpg" is not a
# CIFAR-10 class, so presumably it demonstrates the NO_CIFAR10_CLASS fallback.
example_files = (
    "00_plane.jpg",
    "01_car.jpg",
    "02_house.jpg",
    "03_cat.jpg",
    "04_deer.jpg",
)

# css_code='body{background-image:url("file=wave.mp4");}'
demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=[[path] for path in example_files],
    # css=css_code,
)
demo.launch()