Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image, ImageOps, ImageStat | |
| from transformers import pipeline | |
# Shared image-classification pipeline, loaded once at import time from the
# "kmewhort/beit-sketch-classifier" checkpoint; top_k=5 asks it to return the
# five highest-scoring labels for each image.
PIPE = pipeline(
    "image-classification",
    top_k=5,
    model="kmewhort/beit-sketch-classifier",
)
def preprocess(image: Image.Image):
    """Normalize a user-supplied sketch before classification.

    Any transparency is flattened onto an opaque white canvas, the image is
    reduced to grayscale, and it is inverted when its mean brightness is
    below 128 (i.e. it looks like light strokes on a dark background). The
    result is converted back to RGB for the classifier.

    Args:
        image: The sketch as a PIL image, or ``None`` when nothing was given.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    # Converting straight to "L" discards the alpha channel, so fully
    # transparent pixels keep whatever colour they store (commonly black).
    # A dark-on-transparent drawing would then collapse into a solid dark
    # canvas and the strokes would be lost. Composite onto opaque white first
    # so transparent areas read as background. Opaque images skip this branch
    # and are processed exactly as before.
    has_alpha = image.mode in ("RGBA", "LA") or (
        image.mode in ("P", "L", "RGB") and "transparency" in image.info
    )
    if has_alpha:
        rgba = image.convert("RGBA")
        canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(canvas, rgba)

    img = image.convert("L")
    # Ensure black strokes on white background: a dark average means the
    # background is dark, so flip the image.
    if ImageStat.Stat(img).mean[0] < 128:
        img = ImageOps.invert(img)
    return img.convert("RGB")
def predict(image: Image.Image):
    """Classify a sketch with the module-level pipeline.

    The image is normalized by ``preprocess`` first. If that yields ``None``
    (no image supplied), an empty list is returned instead of calling the
    pipeline; otherwise the pipeline's output is returned unchanged.
    """
    prepared = preprocess(image)
    return [] if prepared is None else PIPE(prepared)
# Gradio UI: a title, a PIL image input, a JSON panel for the results, and a
# button that runs `predict` (also registered under the API name "predict").
with gr.Blocks() as demo:
    gr.Markdown("# QuickDraw Sketch Classifier")
    inp = gr.Image(label="Sketch", type="pil")
    out = gr.JSON(label="Predictions")
    btn = gr.Button("Predict")
    btn.click(
        fn=predict,
        inputs=inp,
        outputs=out,
        api_name="predict",
    )

demo.launch()