Philipp Normann commited on
Commit
7d3432f
Β·
1 Parent(s): 628ad7b

Use seaborn color palette

Browse files
Files changed (4) hide show
  1. Pipfile +1 -0
  2. Pipfile.lock +10 -1
  3. app.py +22 -6
  4. requirements.txt +1 -0
Pipfile CHANGED
@@ -13,6 +13,7 @@ torch = "==2.3.1"
13
  torchvision = "==0.18.1"
14
  lightning = "==2.3.0"
15
  gradio = "*"
 
16
 
17
  [dev-packages]
18
 
 
13
  torchvision = "==0.18.1"
14
  lightning = "==2.3.0"
15
  gradio = "*"
16
+ seaborn = "*"
17
 
18
  [dev-packages]
19
 
Pipfile.lock CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_meta": {
3
  "hash": {
4
- "sha256": "89d094e3bcfe678655d839dafae5d0ca2cea12d8ff08ee98f9ec2c72bd372138"
5
  },
6
  "pipfile-spec": 6,
7
  "requires": {
@@ -1772,6 +1772,15 @@
1772
  "markers": "sys_platform != 'emscripten'",
1773
  "version": "==0.4.9"
1774
  },
 
 
 
 
 
 
 
 
 
1775
  "semantic-version": {
1776
  "hashes": [
1777
  "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c",
 
1
  {
2
  "_meta": {
3
  "hash": {
4
+ "sha256": "6cd84f80a3c605c6c74d9229abcd62e631682c99fa656bdcb81e996e3bb9d715"
5
  },
6
  "pipfile-spec": 6,
7
  "requires": {
 
1772
  "markers": "sys_platform != 'emscripten'",
1773
  "version": "==0.4.9"
1774
  },
1775
+ "seaborn": {
1776
+ "hashes": [
1777
+ "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987",
1778
+ "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"
1779
+ ],
1780
+ "index": "pypi",
1781
+ "markers": "python_version >= '3.8'",
1782
+ "version": "==0.13.2"
1783
+ },
1784
  "semantic-version": {
1785
  "hashes": [
1786
  "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c",
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import random
3
 
4
  import gradio as gr
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import polars as pl
@@ -12,6 +13,9 @@ from torchvision.transforms import v2
12
 
13
  from model import ScribbleItNet
14
 
 
 
 
15
  # Matplotlib configuration
16
  plt.rc('font', size=16)
17
  plt.rc('axes', titlesize=24)
@@ -92,22 +96,34 @@ def process_image(image, current_word):
92
  predictions_df = predictions_df.with_columns(
93
  pl.col("word").str.to_lowercase())
94
  predictions_df = predictions_df.group_by("word").agg(
95
- pl.col("prob").max().alias("prob")).sort("prob").tail(10)
 
96
 
97
  # Visualizing predictions
98
  fig, ax = plt.subplots(figsize=(10, 8))
99
  plt.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
100
  colors = [
101
- "green" if word == current_word else "tab:blue"
 
102
  for word in predictions_df["word"]
103
  ]
104
 
105
  if current_word in predictions_df["word"]:
106
  gr.Info("AI guessed the word correctly! πŸŽ‰")
107
 
108
- ax.barh(predictions_df["word"], predictions_df["prob"], color=colors)
 
 
 
 
 
 
 
 
 
109
  ax.set_title("Top 10 Predictions", pad=15)
110
  ax.set_xlabel("Probability")
 
111
  plt.close(fig)
112
  return fig, current_word
113
 
@@ -118,7 +134,7 @@ def update_image(image):
118
 
119
 
120
  def create_initial_image():
121
- data = np.full((500, 700, 3), 255, dtype=np.uint8) # White image
122
  return Image.fromarray(data)
123
 
124
 
@@ -126,7 +142,7 @@ def create_initial_image():
126
  with gr.Blocks(theme=gr.themes.Soft(),
127
  css="input {font-size: 24px; font-weight: 600;}") as demo_app:
128
  gr.Markdown("# Scribble It! AI Demo 🎨")
129
- gr.Markdown("### Draw the word shown and let the AI guess what it is!")
130
 
131
  with gr.Row():
132
  word_output = gr.Textbox(label="Your word to draw:",
@@ -144,7 +160,7 @@ with gr.Blocks(theme=gr.themes.Soft(),
144
  transforms=[],
145
  layers=False,
146
  value=create_initial_image,
147
- brush=gr.Brush(colors=["#000000", "#FF0000", "#00FF00", "#0000FF"],
148
  default_size=10))
149
  plot_output = gr.Plot(label="Model Guesses")
150
 
 
2
  import random
3
 
4
  import gradio as gr
5
+ import seaborn as sns
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
  import polars as pl
 
13
 
14
  from model import ScribbleItNet
15
 
16
+ # Seaborn configuration
17
+ sns.set_theme()
18
+
19
  # Matplotlib configuration
20
  plt.rc('font', size=16)
21
  plt.rc('axes', titlesize=24)
 
96
  predictions_df = predictions_df.with_columns(
97
  pl.col("word").str.to_lowercase())
98
  predictions_df = predictions_df.group_by("word").agg(
99
+ pl.col("prob").max().alias("prob"))
100
+ predictions_df = predictions_df.sort("prob", descending=True).head(10)
101
 
102
  # Visualizing predictions
103
  fig, ax = plt.subplots(figsize=(10, 8))
104
  plt.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
105
  colors = [
106
+ sns.color_palette()[2]
107
+ if word == current_word else sns.color_palette()[0]
108
  for word in predictions_df["word"]
109
  ]
110
 
111
  if current_word in predictions_df["word"]:
112
  gr.Info("AI guessed the word correctly! πŸŽ‰")
113
 
114
+ sns.barplot(
115
+ data=predictions_df,
116
+ y="word",
117
+ x="prob",
118
+ hue="word",
119
+ orient="h",
120
+ palette=colors,
121
+ legend=False,
122
+ ax=ax,
123
+ )
124
  ax.set_title("Top 10 Predictions", pad=15)
125
  ax.set_xlabel("Probability")
126
+ ax.set_ylabel(None)
127
  plt.close(fig)
128
  return fig, current_word
129
 
 
134
 
135
 
136
  def create_initial_image():
137
+ data = np.full((520, 700, 3), 255, dtype=np.uint8) # White image
138
  return Image.fromarray(data)
139
 
140
 
 
142
  with gr.Blocks(theme=gr.themes.Soft(),
143
  css="input {font-size: 24px; font-weight: 600;}") as demo_app:
144
  gr.Markdown("# Scribble It! AI Demo 🎨")
145
+ gr.Markdown("### Draw the word shown below and let the AI guess what it is!")
146
 
147
  with gr.Row():
148
  word_output = gr.Textbox(label="Your word to draw:",
 
160
  transforms=[],
161
  layers=False,
162
  value=create_initial_image,
163
+ brush=gr.Brush(colors=["#000000", "#4c72b0", "#55a868", "#d62728"],
164
  default_size=10))
165
  plot_output = gr.Plot(label="Model Guesses")
166
 
requirements.txt CHANGED
@@ -2,6 +2,7 @@ gradio==4.36.1; python_version >= '3.8'
2
  huggingface-hub==0.23.4; python_version >= '3.8'
3
  polars==0.20.31; python_version >= '3.8'
4
  matplotlib==3.9.0; python_version >= '3.9'
 
5
  torch==2.3.1; python_full_version >= '3.8.0'
6
  torchvision==0.18.1; python_version >= '3.8'
7
  lightning==2.3.0; python_version >= '3.8'
 
2
  huggingface-hub==0.23.4; python_version >= '3.8'
3
  polars==0.20.31; python_version >= '3.8'
4
  matplotlib==3.9.0; python_version >= '3.9'
5
+ seaborn==0.13.2; python_version >= '3.8'
6
  torch==2.3.1; python_full_version >= '3.8.0'
7
  torchvision==0.18.1; python_version >= '3.8'
8
  lightning==2.3.0; python_version >= '3.8'