celine-li commited on
Commit
0b1e701
·
verified ·
1 Parent(s): 053c5b2

Upload 2 files

Browse files

init quickdraw app

Files changed (2) hide show
  1. app.py +36 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageOps, ImageStat
3
+ from transformers import pipeline
4
+
5
+ PIPE = pipeline(
6
+ task="image-classification",
7
+ model="kmewhort/beit-sketch-classifier",
8
+ top_k=5,
9
+ )
10
+
11
+
12
+ def preprocess(image: Image.Image):
13
+ if image is None:
14
+ return None
15
+ img = image.convert("L")
16
+ # Ensure black strokes on white background
17
+ if ImageStat.Stat(img).mean[0] < 128:
18
+ img = ImageOps.invert(img)
19
+ return img.convert("RGB")
20
+
21
+
22
+ def predict(image: Image.Image):
23
+ img = preprocess(image)
24
+ if img is None:
25
+ return []
26
+ return PIPE(img)
27
+
28
+
29
+ with gr.Blocks() as demo:
30
+ gr.Markdown("# QuickDraw Sketch Classifier")
31
+ inp = gr.Image(type="pil", label="Sketch")
32
+ out = gr.JSON(label="Predictions")
33
+ btn = gr.Button("Predict")
34
+ btn.click(predict, inputs=inp, outputs=out, api_name="predict")
35
+
36
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ pillow