Spaces:
Sleeping
Sleeping
Philipp Normann
commited on
Commit
Β·
7d3432f
1
Parent(s):
628ad7b
Use seaborn color palette
Browse files- Pipfile +1 -0
- Pipfile.lock +10 -1
- app.py +22 -6
- requirements.txt +1 -0
Pipfile
CHANGED
|
@@ -13,6 +13,7 @@ torch = "==2.3.1"
|
|
| 13 |
torchvision = "==0.18.1"
|
| 14 |
lightning = "==2.3.0"
|
| 15 |
gradio = "*"
|
|
|
|
| 16 |
|
| 17 |
[dev-packages]
|
| 18 |
|
|
|
|
| 13 |
torchvision = "==0.18.1"
|
| 14 |
lightning = "==2.3.0"
|
| 15 |
gradio = "*"
|
| 16 |
+
seaborn = "*"
|
| 17 |
|
| 18 |
[dev-packages]
|
| 19 |
|
Pipfile.lock
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
{
|
| 2 |
"_meta": {
|
| 3 |
"hash": {
|
| 4 |
-
"sha256": "
|
| 5 |
},
|
| 6 |
"pipfile-spec": 6,
|
| 7 |
"requires": {
|
|
@@ -1772,6 +1772,15 @@
|
|
| 1772 |
"markers": "sys_platform != 'emscripten'",
|
| 1773 |
"version": "==0.4.9"
|
| 1774 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1775 |
"semantic-version": {
|
| 1776 |
"hashes": [
|
| 1777 |
"sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c",
|
|
|
|
| 1 |
{
|
| 2 |
"_meta": {
|
| 3 |
"hash": {
|
| 4 |
+
"sha256": "6cd84f80a3c605c6c74d9229abcd62e631682c99fa656bdcb81e996e3bb9d715"
|
| 5 |
},
|
| 6 |
"pipfile-spec": 6,
|
| 7 |
"requires": {
|
|
|
|
| 1772 |
"markers": "sys_platform != 'emscripten'",
|
| 1773 |
"version": "==0.4.9"
|
| 1774 |
},
|
| 1775 |
+
"seaborn": {
|
| 1776 |
+
"hashes": [
|
| 1777 |
+
"sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987",
|
| 1778 |
+
"sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"
|
| 1779 |
+
],
|
| 1780 |
+
"index": "pypi",
|
| 1781 |
+
"markers": "python_version >= '3.8'",
|
| 1782 |
+
"version": "==0.13.2"
|
| 1783 |
+
},
|
| 1784 |
"semantic-version": {
|
| 1785 |
"hashes": [
|
| 1786 |
"sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c",
|
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import random
|
| 3 |
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
import polars as pl
|
|
@@ -12,6 +13,9 @@ from torchvision.transforms import v2
|
|
| 12 |
|
| 13 |
from model import ScribbleItNet
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
# Matplotlib configuration
|
| 16 |
plt.rc('font', size=16)
|
| 17 |
plt.rc('axes', titlesize=24)
|
|
@@ -92,22 +96,34 @@ def process_image(image, current_word):
|
|
| 92 |
predictions_df = predictions_df.with_columns(
|
| 93 |
pl.col("word").str.to_lowercase())
|
| 94 |
predictions_df = predictions_df.group_by("word").agg(
|
| 95 |
-
pl.col("prob").max().alias("prob"))
|
|
|
|
| 96 |
|
| 97 |
# Visualizing predictions
|
| 98 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 99 |
plt.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
|
| 100 |
colors = [
|
| 101 |
-
|
|
|
|
| 102 |
for word in predictions_df["word"]
|
| 103 |
]
|
| 104 |
|
| 105 |
if current_word in predictions_df["word"]:
|
| 106 |
gr.Info("AI guessed the word correctly! π")
|
| 107 |
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
ax.set_title("Top 10 Predictions", pad=15)
|
| 110 |
ax.set_xlabel("Probability")
|
|
|
|
| 111 |
plt.close(fig)
|
| 112 |
return fig, current_word
|
| 113 |
|
|
@@ -118,7 +134,7 @@ def update_image(image):
|
|
| 118 |
|
| 119 |
|
| 120 |
def create_initial_image():
|
| 121 |
-
data = np.full((
|
| 122 |
return Image.fromarray(data)
|
| 123 |
|
| 124 |
|
|
@@ -126,7 +142,7 @@ def create_initial_image():
|
|
| 126 |
with gr.Blocks(theme=gr.themes.Soft(),
|
| 127 |
css="input {font-size: 24px; font-weight: 600;}") as demo_app:
|
| 128 |
gr.Markdown("# Scribble It! AI Demo π¨")
|
| 129 |
-
gr.Markdown("### Draw the word shown and let the AI guess what it is!")
|
| 130 |
|
| 131 |
with gr.Row():
|
| 132 |
word_output = gr.Textbox(label="Your word to draw:",
|
|
@@ -144,7 +160,7 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
| 144 |
transforms=[],
|
| 145 |
layers=False,
|
| 146 |
value=create_initial_image,
|
| 147 |
-
brush=gr.Brush(colors=["#000000", "#
|
| 148 |
default_size=10))
|
| 149 |
plot_output = gr.Plot(label="Model Guesses")
|
| 150 |
|
|
|
|
| 2 |
import random
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
+
import seaborn as sns
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import numpy as np
|
| 8 |
import polars as pl
|
|
|
|
| 13 |
|
| 14 |
from model import ScribbleItNet
|
| 15 |
|
| 16 |
+
# Seaborn configuration
|
| 17 |
+
sns.set_theme()
|
| 18 |
+
|
| 19 |
# Matplotlib configuration
|
| 20 |
plt.rc('font', size=16)
|
| 21 |
plt.rc('axes', titlesize=24)
|
|
|
|
| 96 |
predictions_df = predictions_df.with_columns(
|
| 97 |
pl.col("word").str.to_lowercase())
|
| 98 |
predictions_df = predictions_df.group_by("word").agg(
|
| 99 |
+
pl.col("prob").max().alias("prob"))
|
| 100 |
+
predictions_df = predictions_df.sort("prob", descending=True).head(10)
|
| 101 |
|
| 102 |
# Visualizing predictions
|
| 103 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 104 |
plt.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
|
| 105 |
colors = [
|
| 106 |
+
sns.color_palette()[2]
|
| 107 |
+
if word == current_word else sns.color_palette()[0]
|
| 108 |
for word in predictions_df["word"]
|
| 109 |
]
|
| 110 |
|
| 111 |
if current_word in predictions_df["word"]:
|
| 112 |
gr.Info("AI guessed the word correctly! π")
|
| 113 |
|
| 114 |
+
sns.barplot(
|
| 115 |
+
data=predictions_df,
|
| 116 |
+
y="word",
|
| 117 |
+
x="prob",
|
| 118 |
+
hue="word",
|
| 119 |
+
orient="h",
|
| 120 |
+
palette=colors,
|
| 121 |
+
legend=False,
|
| 122 |
+
ax=ax,
|
| 123 |
+
)
|
| 124 |
ax.set_title("Top 10 Predictions", pad=15)
|
| 125 |
ax.set_xlabel("Probability")
|
| 126 |
+
ax.set_ylabel(None)
|
| 127 |
plt.close(fig)
|
| 128 |
return fig, current_word
|
| 129 |
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def create_initial_image():
|
| 137 |
+
data = np.full((520, 700, 3), 255, dtype=np.uint8) # White image
|
| 138 |
return Image.fromarray(data)
|
| 139 |
|
| 140 |
|
|
|
|
| 142 |
with gr.Blocks(theme=gr.themes.Soft(),
|
| 143 |
css="input {font-size: 24px; font-weight: 600;}") as demo_app:
|
| 144 |
gr.Markdown("# Scribble It! AI Demo π¨")
|
| 145 |
+
gr.Markdown("### Draw the word shown below and let the AI guess what it is!")
|
| 146 |
|
| 147 |
with gr.Row():
|
| 148 |
word_output = gr.Textbox(label="Your word to draw:",
|
|
|
|
| 160 |
transforms=[],
|
| 161 |
layers=False,
|
| 162 |
value=create_initial_image,
|
| 163 |
+
brush=gr.Brush(colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
|
| 164 |
default_size=10))
|
| 165 |
plot_output = gr.Plot(label="Model Guesses")
|
| 166 |
|
requirements.txt
CHANGED
|
@@ -2,6 +2,7 @@ gradio==4.36.1; python_version >= '3.8'
|
|
| 2 |
huggingface-hub==0.23.4; python_version >= '3.8'
|
| 3 |
polars==0.20.31; python_version >= '3.8'
|
| 4 |
matplotlib==3.9.0; python_version >= '3.9'
|
|
|
|
| 5 |
torch==2.3.1; python_full_version >= '3.8.0'
|
| 6 |
torchvision==0.18.1; python_version >= '3.8'
|
| 7 |
lightning==2.3.0; python_version >= '3.8'
|
|
|
|
| 2 |
huggingface-hub==0.23.4; python_version >= '3.8'
|
| 3 |
polars==0.20.31; python_version >= '3.8'
|
| 4 |
matplotlib==3.9.0; python_version >= '3.9'
|
| 5 |
+
seaborn==0.13.2; python_version >= '3.8'
|
| 6 |
torch==2.3.1; python_full_version >= '3.8.0'
|
| 7 |
torchvision==0.18.1; python_version >= '3.8'
|
| 8 |
lightning==2.3.0; python_version >= '3.8'
|