quickdraw-guess / app.py
celine-li's picture
Upload 2 files
0b1e701 verified
raw
history blame contribute delete
870 Bytes
import gradio as gr
from PIL import Image, ImageOps, ImageStat
from transformers import pipeline
# Model id handed to transformers.pipeline, and how many of the
# highest-scoring labels it should return for each image.
_SKETCH_MODEL_ID = "kmewhort/beit-sketch-classifier"
_TOP_K = 5

# Built once at module import so every call to predict() reuses the same
# loaded model instead of reloading it per request.
PIPE = pipeline("image-classification", model=_SKETCH_MODEL_ID, top_k=_TOP_K)
def preprocess(image: Image.Image):
    """Normalize an uploaded sketch into the RGB image the classifier expects.

    Steps:
      1. If the image carries transparency, flatten it onto a white
         background so transparent regions count as paper, not ink.
      2. Convert to single-channel grayscale.
      3. If the image is predominantly dark, invert it so strokes end up
         dark on a light background.
      4. Convert back to 3-channel RGB for the model.

    Args:
        image: The uploaded sketch, or ``None`` when no image was provided.

    Returns:
        The preprocessed RGB image, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    # convert("L") ignores alpha. Transparent pixels commonly store RGB
    # (0, 0, 0), so a transparent background would become solid black. The
    # inversion below would then turn the whole canvas white and erase black
    # strokes. Composite onto opaque white first. For fully opaque images
    # this leaves the pixels unchanged.
    has_alpha_band = any(band in ("A", "a") for band in image.getbands())
    if has_alpha_band or "transparency" in image.info:
        rgba = image.convert("RGBA")
        paper = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(paper, rgba)
    img = image.convert("L")
    # Ensure dark strokes on a light background. A mean below mid-gray means
    # the image is mostly dark (e.g. light-on-dark drawing), so flip it.
    if ImageStat.Stat(img).mean[0] < 128:
        img = ImageOps.invert(img)
    return img.convert("RGB")
def predict(image: Image.Image):
    """Classify an uploaded sketch with the module-level pipeline.

    Args:
        image: The uploaded sketch, or ``None`` if nothing was provided.

    Returns:
        Whatever ``PIPE`` returns for the preprocessed image, or an empty
        list when ``preprocess`` yields ``None`` (no image to classify).
    """
    prepared = preprocess(image)
    return [] if prepared is None else PIPE(prepared)
with gr.Blocks() as demo:
    gr.Markdown("# QuickDraw Sketch Classifier")
    # Request RGBA so the alpha channel reaches predict(). The component's
    # default image_mode ("RGB") discards transparency before the handler
    # runs, which loses the information needed to tell a transparent
    # background apart from ink.
    inp = gr.Image(type="pil", image_mode="RGBA", label="Sketch")
    out = gr.JSON(label="Predictions")
    btn = gr.Button("Predict")
    # api_name gives this handler a named endpoint for programmatic access.
    btn.click(predict, inputs=inp, outputs=out, api_name="predict")

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module from tests or tooling does not start a server.
if __name__ == "__main__":
    demo.launch()