| import gradio as gr | |
| from huggingface_hub import from_pretrained_keras | |
| import tensorflow as tf | |
# CIFAR-10 label names, keyed by the integer index of the matching entry in
# the model's output vector (see classify_image).
CLASSES = dict(
    enumerate(
        (
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)

# Side length, in pixels, that every input image is resized to before inference.
IMAGE_SIZE = 32
# NOTE(review): the UI description links https://huggingface.co/keras-io/cct,
# but the weights are loaded from "EdoAbati/cct" -- confirm which repo is intended.
MODEL_REPO_ID = "EdoAbati/cct"

# Load the pretrained Keras model from the Hugging Face Hub once, at import time,
# so every request reuses the same model object.
model = from_pretrained_keras(MODEL_REPO_ID)
def reshape_image(image):
    """Turn a single image into a model-ready batch of size one.

    The input is converted to a tensor, its static shape is pinned to three
    channels, it is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, and a leading
    batch axis is added.

    Args:
        image: An array-like image whose last dimension must be 3.

    Returns:
        A tensor of shape (1, IMAGE_SIZE, IMAGE_SIZE, 3).
    """
    tensor = tf.convert_to_tensor(image)
    # Pin rank 3 with exactly 3 channels; an incompatible input raises here
    # rather than deeper inside the model.
    tensor.set_shape([None, None, 3])
    resized = tf.image.resize(tensor, [IMAGE_SIZE, IMAGE_SIZE])
    # The model consumes batches, so wrap the single image in a batch of one.
    return tf.expand_dims(resized, 0)
def classify_image(input_image):
    """Return class probabilities for a single image.

    Args:
        input_image: The image passed in by the Gradio image input. It is
            assumed to be an H x W x 3 array (``reshape_image`` enforces the
            3-channel part).

    Returns:
        A dict mapping every class name in ``CLASSES`` to its softmax
        probability as a Python float. This is the dict form the Label
        output displays.
    """
    batch = reshape_image(input_image)
    # verbose=0 stops Keras from printing a progress bar to the server log on
    # every request (recent TF/Keras default to verbose="auto").
    logits = model.predict(batch, verbose=0).flatten()
    # The outputs are treated as raw logits; softmax turns them into
    # probabilities. Convert to NumPy once instead of slicing a tensor per class.
    probabilities = tf.nn.softmax(logits).numpy()
    return {label: float(probabilities[index]) for index, label in CLASSES.items()}
# Example images offered in the UI; each inner list is one example's input row.
examples = [[f"./{name}.png"] for name in ("bird", "cat", "dog", "horse")]

title = "Image Classification using Compact Convolutional Transformer (CCT)"

# HTML shown above the interface.
# NOTE(review): the "Model" link points at keras-io/cct, while the weights are
# loaded from "EdoAbati/cct" -- confirm which repo should be referenced.
description = """
Upload an image or select one from the examples and ask the model to label it!
<br />
The model was trained on the <a href="https://www.cs.toronto.edu/~kriz/cifar.html" target="_blank">CIFAR-10 dataset</a>. Therefore, it is able to recognise these 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
<br />
<br />
<p>
<b>Model: https://huggingface.co/keras-io/cct</b>
<br />
<b>Keras Example: https://keras.io/examples/vision/cct/</b>
</p>
<br />
"""

# HTML credits shown below the interface.
article = """
<div style="text-align: center;">
Space by <a href="https://www.linkedin.com/in/edoardoabati/" target="_blank">Edoardo Abati</a>
<br />
Keras example by <a href="https://twitter.com/RisingSayak" target="_blank">Sayak Paul</a>
</div>
"""
# UI components: an image upload feeds classify_image, and its
# {class name: probability} dict is rendered as a label.
image_input = gr.inputs.Image()
label_output = gr.outputs.Label()

# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) belong to the
# legacy Gradio API (deprecated in 3.x, removed in 4.x). Pin the Gradio version,
# or migrate to gr.Image / gr.Label and interface.queue().
interface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="never",  # flagging is disabled for this demo
)

# Start the app, routing requests through Gradio's queue.
interface.launch(enable_queue=True)