"""Gradio demo that classifies hand-drawn QuickDraw-style sketches."""

import gradio as gr
from PIL import Image, ImageOps, ImageStat
from transformers import pipeline

# Loaded once at import time and shared by every request.
PIPE = pipeline(
    task="image-classification",
    model="kmewhort/beit-sketch-classifier",
    top_k=5,
)

# Image modes that carry an alpha channel.
_ALPHA_MODES = ("RGBA", "LA", "PA")


def _flatten_alpha(image: Image.Image) -> Image.Image:
    """Composite *image* onto an opaque white background if it has alpha.

    Transparent pixels are frequently stored as black (0, 0, 0, 0); a plain
    ``convert("L")`` discards alpha and would turn such a background black,
    destroying a black-on-transparent drawing. Images without alpha are
    returned unchanged.
    """
    has_alpha = image.mode in _ALPHA_MODES or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image
    rgba = image.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba)


def preprocess(image: Image.Image):
    """Normalize a sketch to dark strokes on a light background, as RGB.

    Args:
        image: The PIL image supplied by the Gradio input, or ``None`` when
            nothing was provided.

    Returns:
        An RGB ``Image.Image`` ready for the classifier, or ``None`` if
        *image* is ``None``.
    """
    if image is None:
        return None
    img = _flatten_alpha(image).convert("L")
    # Heuristic: a mostly-dark image is assumed to be light strokes on a dark
    # background, so invert it to get dark strokes on a light background.
    if ImageStat.Stat(img).mean[0] < 128:
        img = ImageOps.invert(img)
    return img.convert("RGB")


def predict(image: Image.Image):
    """Return the top-5 classifier predictions for *image*.

    Returns an empty list when no image was supplied; otherwise the raw
    output of the image-classification pipeline.
    """
    img = preprocess(image)
    if img is None:
        return []
    return PIPE(img)


with gr.Blocks() as demo:
    gr.Markdown("# QuickDraw Sketch Classifier")
    inp = gr.Image(type="pil", label="Sketch")
    out = gr.JSON(label="Predictions")
    btn = gr.Button("Predict")
    btn.click(predict, inputs=inp, outputs=out, api_name="predict")

# NOTE(review): launching at import time is kept for compatibility with how
# this script is run; consider an `if __name__ == "__main__":` guard.
demo.launch()