Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import onnxruntime as ort | |
| import polars as pl | |
| import seaborn as sns | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
# Seaborn configuration: apply the default seaborn theme.
sns.set_theme()

# Matplotlib configuration: font-size overrides, applied after the theme call
# above. NOTE(review): order presumably matters if set_theme() touches the same
# rc settings -- keep these overrides last.
for _rc_group, _rc_overrides in (
    ("font", {"size": 16}),
    ("axes", {"titlesize": 24}),
    ("xtick", {"labelsize": 20}),
    ("ytick", {"labelsize": 20}),
):
    plt.rc(_rc_group, **_rc_overrides)
# Load the model
def load_model():
    """Download the ONNX model from the Hugging Face Hub and open it.

    The file ``efficientnet_b0.onnx`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` repository into the local ``assets``
    directory. The access token is read from the ``HF_TOKEN`` environment
    variable (``None`` when unset).

    Returns:
        An ``onnxruntime.InferenceSession`` for the downloaded model file.
    """
    # hf_hub_download returns the local path of the fetched file; use it
    # rather than spelling out the download location a second time.
    model_path = hf_hub_download(
        "ScribbleItAI/efficientnet-b0",
        token=os.getenv("HF_TOKEN"),
        filename="efficientnet_b0.onnx",
        local_dir="assets",
    )
    return ort.InferenceSession(model_path)
# Load vocabulary
def load_vocabulary():
    """Download the vocabulary CSV from the Hugging Face Hub and load it.

    The file ``vocabulary.csv`` is fetched from the
    ``ScribbleItAI/efficientnet-b0`` repository into the local ``assets``
    directory. The access token is read from the ``HF_TOKEN`` environment
    variable (``None`` when unset).

    Returns:
        A polars DataFrame of the CSV contents, sorted by ``label_idx``.
    """
    # hf_hub_download returns the local path of the fetched file; use it
    # rather than spelling out the download location a second time.
    vocabulary_path = hf_hub_download(
        "ScribbleItAI/efficientnet-b0",
        token=os.getenv("HF_TOKEN"),
        filename="vocabulary.csv",
        local_dir="assets",
    )
    return pl.read_csv(vocabulary_path).sort("label_idx")
def compute_word_weights(vocabulary):
    """Build parallel word / sampling-weight lists from the vocabulary.

    Args:
        vocabulary: polars DataFrame with at least the columns ``word`` and
            ``train_count``.

    Returns:
        A ``(words, weights)`` tuple of equally long lists with one entry per
        vocabulary row (so a word appears once for every row that carries it).
        Each weight is that row's ``train_count`` divided by the total
        ``train_count`` of the whole vocabulary.
    """
    # Sum per word first, then across words, to get the overall total.
    per_word_counts = vocabulary.group_by("word").agg(pl.col("train_count").sum())
    grand_total = per_word_counts["train_count"].sum()

    row_words = []
    row_weights = []
    for row in vocabulary.rows(named=True):
        row_words.append(row["word"])
        row_weights.append(row["train_count"] / grand_total)
    return row_words, row_weights
# Inference session shared by all requests, plus the name of its first input,
# which is the key used when feeding an image to the model.
ort_session = load_model()
input_name = ort_session.get_inputs()[0].name

# Vocabulary table and the word / weight lists used for random word selection.
vocabulary = load_vocabulary()
words, weights = compute_word_weights(vocabulary)
def get_random_word():
    """Return one word drawn at random from ``words``, weighted by ``weights``."""
    (chosen_word,) = random.choices(words, weights=weights, k=1)
    return chosen_word
# Process the image drawn on canvas
def _predict_top_words(input_img, limit=10):
    """Run the model on ``input_img`` and return its best-scoring words.

    Returns a polars DataFrame with the columns ``word`` and ``prob``, sorted
    by ``prob`` in descending order and truncated to ``limit`` rows.
    """
    # NOTE(review): the two model outputs are taken to be class indices and
    # their probabilities, in that order -- confirm against the exported model.
    indices, preds = ort_session.run(None, {input_name: np.array(input_img)})

    predictions = []
    for pred, idx in zip(preds, indices):
        # Row position is used as the class index; this assumes the
        # label_idx-sorted vocabulary is contiguous from 0 -- TODO confirm.
        vocab = vocabulary.row(idx, named=True)
        predictions.append({"word": vocab["word"], "category": vocab["category_name"], "prob": pred})

    predictions_df = pl.DataFrame(predictions)
    # Several vocabulary rows may carry the same word; keep each word's best score.
    predictions_df = predictions_df.group_by("word").agg(pl.col("prob").max().alias("prob"))
    return predictions_df.sort("prob", descending=True).head(limit)


def _plot_predictions(predictions_df, current_word):
    """Draw a horizontal bar chart of the predictions, highlighting ``current_word``."""
    # Look the palette up once instead of once (or twice) per bar.
    palette = sns.color_palette()
    base_color, highlight_color = palette[0], palette[2]
    colors = [
        highlight_color if word == current_word else base_color
        for word in predictions_df["word"]
    ]

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Adjust this figure explicitly; plt.subplots_adjust() targets pyplot's
        # global "current figure", which another concurrent request may own.
        fig.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
        sns.barplot(
            data=predictions_df,
            y="word",
            x="prob",
            hue="word",
            orient="h",
            palette=colors,
            legend=False,
            ax=ax,
        )
        ax.set_title("Top 10 Predictions", pad=15)
        ax.set_xlabel("Probability")
        ax.set_ylabel(None)
    finally:
        # Unregister the figure from pyplot so figures do not pile up across
        # requests; the Figure object itself stays usable by the caller.
        plt.close(fig)
    return fig


def process_image(image, current_word):
    """Classify the current drawing and plot the model's top guesses.

    Args:
        image: Value of the image editor; its ``"composite"`` entry is the
            image that gets classified after resizing to 224x224.
        current_word: The word the user was asked to draw.

    Returns:
        ``(figure, current_word)``. ``figure`` is a matplotlib figure with the
        top 10 predictions, or ``None`` when there is no composite image to
        classify. ``current_word`` is passed through unchanged.

    Side effects:
        Shows a ``gr.Info`` notification when the top prediction equals
        ``current_word``.
    """
    composite = None if image is None else image.get("composite")
    if composite is None:
        # Nothing to classify: clear the plot instead of raising.
        return None, current_word

    predictions_df = _predict_top_words(composite.resize((224, 224)))

    if current_word == predictions_df["word"][0]:
        gr.Info("AI guessed the word correctly! π")

    # Visualizing predictions
    return _plot_predictions(predictions_df, current_word), current_word
def create_initial_image():
    """Return a blank canvas: a 700x520 (width x height) all-white RGB image."""
    canvas_height, canvas_width, channels = 520, 700, 3
    white_pixels = np.full(
        shape=(canvas_height, canvas_width, channels),
        fill_value=255,
        dtype=np.uint8,
    )
    return Image.fromarray(white_pixels)
# Interface definition
# NOTE(review): the nesting of the layout below (which components share a row,
# and the launch call sitting outside the Blocks context) should be confirmed
# against the intended layout.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="input {font-size: 24px; font-weight: 600;}",
) as demo_app:
    gr.Markdown("# Scribble It! AI Demo π¨")
    gr.Markdown("### Draw the word shown below and let the AI guess what it is!")

    # Word prompt and the button that requests a new one.
    with gr.Row():
        word_output = gr.Textbox(
            label="Your word to draw:",
            value=get_random_word,
            interactive=False,
            scale=1,
            max_lines=1,
        )
        new_word_button = gr.Button("New Word", scale=0, variant="primary")

    # Drawing canvas and the prediction chart.
    with gr.Row():
        image_editor = gr.ImageEditor(
            label="Draw Here!",
            image_mode="RGB",
            type="pil",
            sources=[],
            transforms=[],
            layers=False,
            value=create_initial_image,
            brush=gr.Brush(
                colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
                default_size=5,
            ),
        )
        plot_output = gr.Plot(label="Model Guesses")

    # Clearing the editor restores the blank white canvas.
    image_editor.clear(create_initial_image, outputs=image_editor)
    # Every change to the drawing re-runs the classifier and refreshes the plot.
    image_editor.change(
        process_image,
        inputs=[image_editor, word_output],
        outputs=[plot_output, word_output],
    )
    # "New Word" picks a new word and resets the canvas (two separate listeners).
    new_word_button.click(get_random_word, outputs=word_output)
    new_word_button.click(create_initial_image, outputs=image_editor)

demo_app.launch()