import os
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import polars as pl
import seaborn as sns
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from PIL import Image

# Hugging Face repository holding the ONNX model and its vocabulary, and the
# local directory the files are downloaded into.
MODEL_REPO_ID = "ScribbleItAI/efficientnet-b0"
ASSETS_DIR = "assets"

# Seaborn configuration
sns.set_theme()

# Matplotlib configuration
plt.rc('font', size=16)
plt.rc('axes', titlesize=24)
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)


def _download_asset(filename):
    """Download *filename* from the model repository and return its local path.

    The file is placed under ``ASSETS_DIR``; the ``HF_TOKEN`` environment
    variable (if set) is passed as the access token.
    """
    return hf_hub_download(
        MODEL_REPO_ID,
        token=os.getenv("HF_TOKEN"),
        filename=filename,
        local_dir=ASSETS_DIR,
    )


# Load the model
def load_model() -> ort.InferenceSession:
    """Download the ONNX model and open an inference session on it."""
    # Use the path reported by the hub client rather than re-deriving it.
    return ort.InferenceSession(_download_asset("efficientnet_b0.onnx"))


# Load vocabulary
def load_vocabulary() -> pl.DataFrame:
    """Download the vocabulary CSV and return it sorted by ``label_idx``.

    NOTE(review): the sort presumably makes row position line up with the
    model's output indices, which ``process_image`` uses for row lookup --
    confirm against the exported model.
    """
    return pl.read_csv(_download_asset("vocabulary.csv")).sort("label_idx")


def compute_word_weights(vocabulary):
    """Return sampling words and weights derived from ``train_count``.

    Args:
        vocabulary: DataFrame with at least ``word`` and ``train_count``
            columns.

    Returns:
        A ``(words, weights)`` pair of equally long lists, one entry per
        vocabulary row (a word occurring in several rows appears several
        times). Each weight is the row's ``train_count`` divided by the sum
        of all ``train_count`` values.

    Raises:
        ValueError: If the total ``train_count`` is zero (including an empty
            vocabulary), since no sampling distribution can be built.
    """
    total_train_count = vocabulary["train_count"].sum()
    if not total_train_count:
        raise ValueError("vocabulary has no training counts to derive weights from")

    sampled_words = vocabulary["word"].to_list()
    sampled_weights = [
        count / total_train_count
        for count in vocabulary["train_count"].to_list()
    ]
    return sampled_words, sampled_weights


ort_session = load_model()
input_name = ort_session.get_inputs()[0].name

vocabulary = load_vocabulary()
words, weights = compute_word_weights(vocabulary)


def get_random_word():
    """Pick one word at random, weighted by its share of the training counts."""
    return random.choices(words, weights=weights)[0]


def _predict_top_words(pil_image, limit=10):
    """Run the model on *pil_image* and return the best ``limit`` word guesses.

    The image is resized to 224x224 and passed to the session as a NumPy
    array. The session is unpacked into ``(indices, preds)``; each index is
    looked up as a row position in ``vocabulary``. When a word occurs more
    than once, only its highest probability is kept.

    Returns:
        DataFrame with ``word`` and ``prob`` columns, sorted by descending
        ``prob`` and truncated to ``limit`` rows.
    """
    model_input = np.array(pil_image.resize((224, 224)))
    indices, preds = ort_session.run(None, {input_name: model_input})

    predictions = []
    for pred, idx in zip(preds, indices):
        vocab = vocabulary.row(idx, named=True)
        predictions.append(
            {"word": vocab["word"], "category": vocab["category_name"], "prob": pred}
        )

    return (
        pl.DataFrame(predictions)
        .group_by("word")
        .agg(pl.col("prob").max().alias("prob"))
        .sort("prob", descending=True)
        .head(limit)
    )


def _plot_predictions(predictions_df, current_word):
    """Draw a horizontal bar chart of the predictions, highlighting *current_word*.

    The figure is created directly from ``matplotlib.figure.Figure`` instead
    of through ``pyplot``: pyplot keeps every figure it creates alive in a
    global registry until it is explicitly closed, which leaks one figure per
    canvas change in a long-running app, and its "current figure" state is
    shared between concurrently running handlers.
    """
    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    fig.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)

    palette = sns.color_palette()
    highlight_color, default_color = palette[2], palette[0]
    colors = [
        highlight_color if word == current_word else default_color
        for word in predictions_df["word"]
    ]

    sns.barplot(
        data=predictions_df,
        y="word",
        x="prob",
        hue="word",
        orient="h",
        palette=colors,
        legend=False,
        ax=ax,
    )
    ax.set_title("Top 10 Predictions", pad=15)
    ax.set_xlabel("Probability")
    ax.set_ylabel(None)
    return fig


# Process the image drawn on canvas
def process_image(image, current_word):
    """Classify the current drawing and plot the model's top guesses.

    Args:
        image: Value delivered by the ``gr.ImageEditor`` component; its
            ``"composite"`` entry is the drawing to classify. May be ``None``
            (or carry a ``None`` composite) when there is nothing to classify.
        current_word: The word the user was asked to draw.

    Returns:
        ``(figure, current_word)``. ``figure`` is ``None`` when there is no
        drawing, which leaves the plot empty; ``current_word`` is passed
        through unchanged.
    """
    if image is None or image.get("composite") is None:
        return None, current_word

    predictions_df = _predict_top_words(image["composite"])

    if current_word == predictions_df["word"][0]:
        gr.Info("AI guessed the word correctly! 🎉")

    return _plot_predictions(predictions_df, current_word), current_word


def create_initial_image():
    """Return a blank white 700x520 RGB image used as the empty canvas."""
    data = np.full((520, 700, 3), 255, dtype=np.uint8)  # White image
    return Image.fromarray(data)


# Interface definition
with gr.Blocks(theme=gr.themes.Soft(), css="input {font-size: 24px; font-weight: 600;}") as demo_app:
    gr.Markdown("# Scribble It! AI Demo 🎨")
    gr.Markdown("### Draw the word shown below and let the AI guess what it is!")

    with gr.Row():
        word_output = gr.Textbox(
            label="Your word to draw:",
            value=get_random_word,
            interactive=False,
            scale=1,
            max_lines=1,
        )
        new_word_button = gr.Button("New Word", scale=0, variant="primary")

    with gr.Row():
        image_editor = gr.ImageEditor(
            label="Draw Here!",
            image_mode="RGB",
            type="pil",
            sources=[],
            transforms=[],
            layers=False,
            value=create_initial_image,
            brush=gr.Brush(
                colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
                default_size=5,
            ),
        )
        plot_output = gr.Plot(label="Model Guesses")

    # Clearing the editor restores the blank canvas; every change re-runs the model.
    image_editor.clear(create_initial_image, outputs=image_editor)
    image_editor.change(
        process_image,
        inputs=[image_editor, word_output],
        outputs=[plot_output, word_output],
    )

    # "New Word" both draws a fresh word and resets the canvas.
    new_word_button.click(get_random_word, outputs=word_output)
    new_word_button.click(create_initial_image, outputs=image_editor)

demo_app.launch()