Spaces:
Sleeping
Sleeping
File size: 4,709 Bytes
69c5e5e a0c5ceb 970bed6 69c5e5e 399e690 69c5e5e 7d3432f 69c5e5e a0c5ceb 69c5e5e 8b26de9 69c5e5e 8b26de9 69c5e5e a0c5ceb 69c5e5e d9bafc3 970bed6 69c5e5e 8b26de9 970bed6 69c5e5e d9bafc3 a0c5ceb d9bafc3 a0c5ceb 8b26de9 970bed6 69c5e5e 970bed6 d9bafc3 8b26de9 a0c5ceb 8b26de9 7d3432f a0c5ceb 69c5e5e 8b26de9 a0c5ceb 69c5e5e 628ad7b 8b26de9 628ad7b 7d3432f a0c5ceb 7d3432f a0c5ceb 7d3432f a0c5ceb 8b26de9 a0c5ceb 7d3432f a0c5ceb d9bafc3 a0c5ceb 8b26de9 a0c5ceb 8b26de9 a0c5ceb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import random
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import polars as pl
import seaborn as sns
from huggingface_hub import hf_hub_download
from PIL import Image
# Seaborn configuration. Applied first: set_theme writes matplotlib rcParams,
# so it must not run after the font overrides below.
sns.set_theme()
# Matplotlib configuration: larger fonts for the on-screen prediction chart.
plt.rcParams.update({
    "font.size": 16,
    "axes.titlesize": 24,
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
})
# Load the model
def load_model():
    """Download the EfficientNet-B0 ONNX model and open an inference session.

    The file ``efficientnet_b0.onnx`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` Hub repo into the local ``assets``
    directory. The value of the ``HF_TOKEN`` environment variable (``None`` if
    unset) is passed as the access token.

    Returns:
        onnxruntime.InferenceSession: session over the downloaded model file.
    """
    # Use the local path hf_hub_download reports rather than re-spelling
    # "assets/efficientnet_b0.onnx", so filename/local_dir and the path we
    # open can never drift apart.
    model_path = hf_hub_download("ScribbleItAI/efficientnet-b0",
                                 token=os.getenv("HF_TOKEN"),
                                 filename="efficientnet_b0.onnx",
                                 local_dir="assets")
    return ort.InferenceSession(model_path)
# Load vocabulary
def load_vocabulary():
    """Download the model's vocabulary CSV and load it sorted by ``label_idx``.

    The file ``vocabulary.csv`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` Hub repo into the local ``assets``
    directory. The value of ``HF_TOKEN`` (``None`` if unset) is passed as the
    access token.

    Returns:
        polars.DataFrame: the vocabulary table, sorted by ``label_idx``.
        NOTE(review): callers look rows up by position using model output
        indices, which presumably assumes ``label_idx`` runs 0..N-1 with no
        gaps; confirm.
    """
    # Open the path hf_hub_download reports instead of hard-coding it again.
    vocab_path = hf_hub_download("ScribbleItAI/efficientnet-b0",
                                 token=os.getenv("HF_TOKEN"),
                                 filename="vocabulary.csv",
                                 local_dir="assets")
    return pl.read_csv(vocab_path).sort("label_idx")
def compute_word_weights(vocabulary):
    """Return parallel lists of words and their sampling weights.

    Each vocabulary row contributes one entry: its ``word`` and that row's
    share of the total ``train_count``. A word listed on several rows appears
    several times, so when the lists are fed to ``random.choices`` the word's
    overall probability is the sum of its rows' shares.

    Args:
        vocabulary: polars DataFrame with at least ``word`` and
            ``train_count`` columns.

    Returns:
        tuple[list, list]: ``(words, weights)`` in row order. The weights sum
        to 1 (up to float rounding). If the total ``train_count`` is zero,
        every row gets an equal weight instead of raising ZeroDivisionError.
    """
    rows = vocabulary.rows(named=True)
    # Summing the column directly yields the same total as grouping by word
    # and summing the groups; the per-word grouping was never used otherwise.
    total_train_count = vocabulary["train_count"].sum()
    words = [row["word"] for row in rows]
    if not total_train_count:
        # Degenerate counts: fall back to uniform sampling rather than dividing by zero.
        weights = [1 / len(rows)] * len(rows) if rows else []
        return words, weights
    weights = [row["train_count"] / total_train_count for row in rows]
    return words, weights
# Module-level state shared by the callbacks below: process_image uses
# ort_session, input_name and vocabulary; get_random_word uses words/weights.
# Both loaders call hf_hub_download, so executing this module performs network I/O.
ort_session = load_model()
# Name of the model's first input; process_image feeds the drawing under this key.
input_name = ort_session.get_inputs()[0].name
vocabulary = load_vocabulary()
words, weights = compute_word_weights(vocabulary)
def get_random_word():
    """Pick one word at random, weighted by the module-level ``weights``."""
    (picked,) = random.choices(words, weights=weights, k=1)
    return picked
# Process the image drawn on canvas
def process_image(image, current_word):
    """Classify the current drawing and chart the model's top-10 guesses.

    Args:
        image: Value of the ``gr.ImageEditor`` component; its ``"composite"``
            entry holds the flattened drawing as a PIL image.
        current_word: The word the player was asked to draw.

    Returns:
        tuple: ``(figure, current_word)``. ``figure`` is a matplotlib Figure
        for ``gr.Plot``, or ``None`` when the editor holds no image.
        ``current_word`` is passed back unchanged.
    """
    if image is None or image["composite"] is None:
        # Nothing to classify (e.g. the editor value is momentarily empty): clear the plot.
        return None, current_word

    predictions_df = _top_word_predictions(image["composite"])
    if current_word == predictions_df["word"][0]:
        gr.Info("AI guessed the word correctly! 🎉")
    return _plot_predictions(predictions_df, current_word), current_word


def _top_word_predictions(drawing):
    """Run the ONNX model on *drawing* and return its 10 most probable words."""
    input_img = drawing.resize((224, 224))
    # NOTE(review): the raw HxWxC uint8 array is fed as-is, so the exported
    # graph presumably handles batching/normalisation itself; confirm.
    indices, preds = ort_session.run(None, {input_name: np.array(input_img)})
    predictions = []
    for pred, idx in zip(preds, indices):
        # Convert to plain Python scalars before handing values to polars.
        vocab = vocabulary.row(int(idx), named=True)
        predictions.append({"word": vocab["word"], "prob": float(pred)})
    return (
        pl.DataFrame(predictions)
        # A word can appear under several labels; keep its most confident score.
        .group_by("word")
        .agg(pl.col("prob").max().alias("prob"))
        .sort("prob", descending=True)
        .head(10)
    )


def _plot_predictions(predictions_df, current_word):
    """Return a horizontal bar chart of *predictions_df*, highlighting *current_word*."""
    palette = sns.color_palette()
    highlight, default = palette[2], palette[0]
    colors = [highlight if word == current_word else default
              for word in predictions_df["word"]]

    fig, ax = plt.subplots(figsize=(10, 8))
    # Adjust *this* figure: plt.subplots_adjust acts on pyplot's global
    # "current figure", which a concurrently running callback may have replaced.
    fig.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
    sns.barplot(
        data=predictions_df,
        y="word",
        x="prob",
        hue="word",
        orient="h",
        palette=colors,
        legend=False,
        ax=ax,
    )
    ax.set_title("Top 10 Predictions", pad=15)
    ax.set_xlabel("Probability")
    ax.set_ylabel(None)
    # This runs on every canvas change: remove the figure from pyplot's registry
    # so figures do not accumulate. The Figure object itself stays renderable.
    plt.close(fig)
    return fig
def create_initial_image():
    """Return a blank, all-white RGB canvas 700 px wide and 520 px tall."""
    return Image.new("RGB", (700, 520), (255, 255, 255))
def _start_new_round():
    """Return a fresh word and a blank canvas for the "New Word" button."""
    return get_random_word(), create_initial_image()


# Interface definition
with gr.Blocks(theme=gr.themes.Soft(), css="input {font-size: 24px; font-weight: 600;}") as demo_app:
    gr.Markdown("# Scribble It! AI Demo 🎨")
    gr.Markdown("### Draw the word shown below and let the AI guess what it is!")
    with gr.Row():
        word_output = gr.Textbox(label="Your word to draw:",
                                 value=get_random_word,
                                 interactive=False,
                                 scale=1,
                                 max_lines=1)
        new_word_button = gr.Button("New Word", scale=0, variant="primary")
    with gr.Row():
        image_editor = gr.ImageEditor(label="Draw Here!",
                                      image_mode="RGB",
                                      type="pil",
                                      sources=[],
                                      transforms=[],
                                      layers=False,
                                      value=create_initial_image,
                                      brush=gr.Brush(colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
                                                     default_size=5))
        plot_output = gr.Plot(label="Model Guesses")
    image_editor.clear(create_initial_image, outputs=image_editor)
    # Re-classify on every edit; process_image echoes the word back unchanged.
    image_editor.change(process_image, inputs=[image_editor, word_output], outputs=[plot_output, word_output])
    # A single event updates the word and resets the canvas together. With two
    # independent click listeners, the canvas reset could trigger the change
    # handler above before the new word reached the textbox, so it would be
    # scored against the stale word.
    new_word_button.click(_start_new_round, outputs=[word_output, image_editor])
demo_app.launch()
|