# demo / app.py
# Philipp Normann
# Argsort operation is now integrated into the ONNX model
# 970bed6
import os
import random
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import polars as pl
import seaborn as sns
from huggingface_hub import hf_hub_download
from PIL import Image
# Seaborn configuration: apply the default seaborn theme first, so the
# matplotlib overrides below are applied on top of it.
sns.set_theme()

# Matplotlib configuration: font sizes used by every figure in this app.
plt.rcParams.update(
    {
        "font.size": 16,
        "axes.titlesize": 24,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
    }
)
# Load the model
def load_model():
    """Download the ONNX model and open an inference session on it.

    The file ``efficientnet_b0.onnx`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` Hugging Face repository into the local
    ``assets`` directory. The value of the ``HF_TOKEN`` environment variable
    (``None`` when unset) is passed as the access token.

    Returns:
        An ``onnxruntime.InferenceSession`` for the downloaded model file.
    """
    # Load from the path reported by hf_hub_download instead of spelling the
    # location out a second time, so the download target and the load path
    # cannot drift apart.
    model_path = hf_hub_download(
        "ScribbleItAI/efficientnet-b0",
        token=os.getenv("HF_TOKEN"),
        filename="efficientnet_b0.onnx",
        local_dir="assets",
    )
    return ort.InferenceSession(model_path)
# Load vocabulary
def load_vocabulary():
    """Download the vocabulary CSV and return it ordered by label index.

    The file ``vocabulary.csv`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` Hugging Face repository into the local
    ``assets`` directory. The value of the ``HF_TOKEN`` environment variable
    (``None`` when unset) is passed as the access token.

    Returns:
        A polars ``DataFrame`` read from the CSV, sorted by its ``label_idx``
        column.
    """
    # Read from the path reported by hf_hub_download instead of spelling the
    # location out a second time, so the download target and the read path
    # cannot drift apart.
    vocabulary_path = hf_hub_download(
        "ScribbleItAI/efficientnet-b0",
        token=os.getenv("HF_TOKEN"),
        filename="vocabulary.csv",
        local_dir="assets",
    )
    return pl.read_csv(vocabulary_path).sort("label_idx")
def compute_word_weights(vocabulary):
    """Return the vocabulary words and their relative training frequencies.

    Args:
        vocabulary: polars ``DataFrame`` with a ``word`` and a ``train_count``
            column.

    Returns:
        A ``(words, weights)`` pair of equal-length lists in row order. Each
        weight is that row's ``train_count`` divided by the sum of the whole
        ``train_count`` column. A word that appears in several rows appears
        several times in ``words``, once per row.
    """
    counts = vocabulary["train_count"]
    # Weights are per row, so the plain column sum is the only total needed;
    # no per-word grouping is required to obtain it.
    total_train_count = counts.sum()
    words = vocabulary["word"].to_list()
    weights = [count / total_train_count for count in counts.to_list()]
    return words, weights
# Module-level state shared by the functions below; built once at import time.
ort_session = load_model()
# Name of the model's first input, used as the feed key in ort_session.run().
input_name = ort_session.get_inputs()[0].name
vocabulary = load_vocabulary()
# Parallel lists consumed by random.choices() in get_random_word().
words, weights = compute_word_weights(vocabulary)
def get_random_word():
    """Draw a single word from ``words``, sampled according to ``weights``."""
    (chosen_word,) = random.choices(words, weights=weights, k=1)
    return chosen_word
# Process the image drawn on canvas
def process_image(image, current_word):
    """Classify the current drawing and plot the model's top guesses.

    Args:
        image: Value emitted by the ``gr.ImageEditor`` component. Its
            ``"composite"`` entry is read as the PIL image to classify.
        current_word: The word the user was asked to draw. Its bar is
            highlighted in the plot.

    Returns:
        A ``(figure, current_word)`` tuple. ``figure`` is a matplotlib bar
        chart of the ten most probable words, or ``None`` when there is no
        composite image to classify. ``current_word`` is passed through
        unchanged.
    """
    # Guard against an empty editor value so the event handler returns an
    # empty plot instead of raising.
    composite = image.get("composite") if image else None
    if composite is None:
        return None, current_word

    input_img = composite.resize((224, 224))
    # The model returns the vocabulary row indices together with their scores.
    indices, preds = ort_session.run(None, {input_name: np.array(input_img)})

    predictions = []
    for pred, idx in zip(preds, indices):
        vocab = vocabulary.row(idx, named=True)
        predictions.append({"word": vocab["word"], "category": vocab["category_name"], "prob": pred})

    # Collapse rows that share a word, keeping that word's highest score, then
    # keep the ten best words.
    predictions_df = pl.DataFrame(predictions)
    predictions_df = predictions_df.group_by("word").agg(pl.col("prob").max().alias("prob"))
    predictions_df = predictions_df.sort("prob", descending=True).head(10)
    top_words = predictions_df["word"]

    if current_word == top_words[0]:
        gr.Info("AI guessed the word correctly! 🎉")

    # Look the palette up once instead of twice per bar.
    palette = sns.color_palette()
    default_color, highlight_color = palette[0], palette[2]
    colors = [highlight_color if word == current_word else default_color for word in top_words]

    # Visualizing predictions
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
        sns.barplot(
            data=predictions_df,
            y="word",
            x="prob",
            hue="word",
            orient="h",
            palette=colors,
            legend=False,
            ax=ax,
        )
        ax.set_title("Top 10 Predictions", pad=15)
        ax.set_xlabel("Probability")
        ax.set_ylabel(None)
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # This handler runs on every canvas change, so unclosed figures would
        # pile up. The Figure object itself stays usable by the caller.
        plt.close(fig)

    return fig, current_word
def create_initial_image():
    """Return a blank white 700x520 (width x height) RGB canvas as a PIL image."""
    height, width, channels = 520, 700, 3
    white_pixels = np.full(shape=(height, width, channels), fill_value=255, dtype=np.uint8)
    return Image.fromarray(white_pixels)
# Interface definition
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="input {font-size: 24px; font-weight: 600;}",
) as demo_app:
    gr.Markdown("# Scribble It! AI Demo 🎨")
    gr.Markdown("### Draw the word shown below and let the AI guess what it is!")

    # Top row: the word to draw and a button that requests a new one.
    with gr.Row():
        word_output = gr.Textbox(
            label="Your word to draw:",
            value=get_random_word,
            interactive=False,
            scale=1,
            max_lines=1,
        )
        new_word_button = gr.Button("New Word", scale=0, variant="primary")

    # Bottom row: the drawing canvas next to the prediction chart.
    with gr.Row():
        image_editor = gr.ImageEditor(
            label="Draw Here!",
            image_mode="RGB",
            type="pil",
            sources=[],
            transforms=[],
            layers=False,
            value=create_initial_image,
            brush=gr.Brush(
                colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
                default_size=5,
            ),
        )
        plot_output = gr.Plot(label="Model Guesses")

    # Clearing the editor restores the white canvas.
    image_editor.clear(create_initial_image, outputs=image_editor)

    # Every change to the drawing re-runs the classifier and redraws the chart.
    image_editor.change(
        process_image,
        inputs=[image_editor, word_output],
        outputs=[plot_output, word_output],
    )

    # "New Word" picks another word and resets the canvas.
    new_word_button.click(get_random_word, outputs=word_output)
    new_word_button.click(create_initial_image, outputs=image_editor)

demo_app.launch()