Philipp Normann committed on
Commit
a0c5ceb
·
1 Parent(s): 69c5e5e

Migrate to gradio

Browse files
Files changed (6) hide show
  1. .streamlit/config.toml +0 -7
  2. Pipfile +20 -0
  3. Pipfile.lock +0 -0
  4. README.md +2 -2
  5. app.py +87 -81
  6. requirements.txt +1 -2
.streamlit/config.toml DELETED
@@ -1,7 +0,0 @@
1
- [client]
2
- showErrorDetails = false
3
-
4
- [theme]
5
- primaryColor = "#3a6ef3bd"
6
- backgroundColor = "#edbc41"
7
- textColor="#000000"
 
 
 
 
 
 
 
 
Pipfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+ streamlit = "==1.35.0"
8
+ streamlit-drawable-canvas = "==0.9.3"
9
+ huggingface-hub = "==0.23.4"
10
+ polars = "==0.20.31"
11
+ matplotlib = "==3.9.0"
12
+ torch = "==2.3.1"
13
+ torchvision = "==0.18.1"
14
+ lightning = "==2.3.0"
15
+ gradio = "*"
16
+
17
+ [dev-packages]
18
+
19
+ [requires]
20
+ python_version = "3.12"
Pipfile.lock ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -3,8 +3,8 @@ title: "Scribble It! AI Demo"
3
  emoji: 🎨
4
  colorFrom: yellow
5
  colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.35.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
3
  emoji: 🎨
4
  colorFrom: yellow
5
  colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,42 +1,25 @@
1
  import os
2
  import random
3
 
 
 
 
4
  import polars as pl
5
- import streamlit as st
6
  import torch
7
  from huggingface_hub import hf_hub_download
8
- from matplotlib import pyplot as plt
9
  from PIL import Image
10
- from streamlit_drawable_canvas import st_canvas
11
  from torchvision.transforms import v2
12
 
13
  from model import ScribbleItNet
14
 
15
- # Page configuration
16
- st.set_page_config(page_title="Scribble It! AI Demo 🎨")
17
- st.title("Scribble It! AI Demo 🎨")
18
-
19
- # Set the background image
20
- background_image = """
21
- <style>
22
- [data-testid="stAppViewContainer"] > .main {
23
- background-image: url("https://detach-entertainment.com/img/clouds.ed95a9c8.svg");
24
- background-color: #edbc41;
25
- }
26
- </style>
27
- """
28
-
29
- st.markdown(background_image, unsafe_allow_html=True)
30
-
31
  # Matplotlib configuration
32
  plt.rc('font', size=16)
33
- plt.rc('axes', titlesize=16)
34
- plt.rc('xtick', labelsize=24)
35
  plt.rc('ytick', labelsize=20)
36
 
37
 
38
  # Load the model
39
- @st.cache_resource
40
  def load_model():
41
  hf_hub_download("ScribbleItAI/efficientnet-b0",
42
  token=os.getenv("HF_TOKEN"),
@@ -49,6 +32,8 @@ def load_model():
49
 
50
 
51
  model = load_model()
 
 
52
  transform = v2.Compose([
53
  v2.Resize((224, 224)),
54
  v2.ToDtype(torch.float32, scale=True),
@@ -56,8 +41,7 @@ transform = v2.Compose([
56
  ])
57
 
58
 
59
- # Load the vocabulary
60
- @st.cache_data
61
  def load_vocabulary():
62
  hf_hub_download("ScribbleItAI/efficientnet-b0",
63
  token=os.getenv("HF_TOKEN"),
@@ -68,49 +52,19 @@ def load_vocabulary():
68
 
69
  vocabulary = load_vocabulary()
70
  idx2vocab = {row["word_idx"]: row for row in vocabulary}
71
- vocabulary = {
72
- f"{row['word_hash']}_{row['category_idx']}": row
73
- for row in vocabulary
74
- }
75
-
76
- # Sidebar
77
- with st.sidebar:
78
- drawing_mode = st.selectbox(
79
- "Drawing tool:",
80
- ("freedraw", "point", "line", "rect", "circle", "transform"))
81
-
82
- stroke_width = st.slider("Stroke width: ", 1, 25, 5)
83
- if drawing_mode == 'point':
84
- point_display_radius = st.slider("Point display radius: ", 1, 25, 5)
85
- stroke_color = st.color_picker("Stroke color hex: ")
86
- bg_color = st.color_picker("Background color hex: ", "#ffffff")
87
- realtime_update = st.checkbox("Update in realtime", True)
88
-
89
- if st.button("New word") or "sample" not in st.session_state:
90
- st.session_state.sample = random.choice(list(vocabulary.values()))
91
-
92
- st.markdown(f" Draw a: **{st.session_state.sample['word']}**")
93
-
94
- # Canvas
95
- canvas_result = st_canvas(
96
- stroke_width=stroke_width,
97
- stroke_color=stroke_color,
98
- background_color=bg_color,
99
- update_streamlit=realtime_update,
100
- height=500,
101
- width=800,
102
- drawing_mode=drawing_mode,
103
- point_display_radius=point_display_radius
104
- if drawing_mode == 'point' else 0,
105
- key="canvas",
106
- )
107
-
108
- # Inference
109
- if canvas_result.image_data is not None:
110
- img = canvas_result.image_data
111
- img = torch.tensor(img)[:, :, :3].permute(2, 0, 1)
112
- img = transform(img)
113
- outputs = model(img.unsqueeze(0).to(model.device))
114
  outputs = torch.softmax(outputs, dim=1)
115
  preds, indices = outputs.topk(100, dim=1)
116
 
@@ -125,19 +79,71 @@ if canvas_result.image_data is not None:
125
  "category": vocab["category_name"],
126
  "prob": pred
127
  })
128
- predictions = pl.DataFrame(predictions)
129
- predictions = predictions.group_by("word").agg(
130
- pl.col("prob").max().alias("prob"))
131
- predictions = predictions.sort("prob")
132
- predictions = predictions.tail(10)
133
-
134
- # Plot the predictions
135
- fig = plt.figure(figsize=(14, 10))
 
 
136
  colors = [
137
- "green" if word == st.session_state.sample["word"] else "tab:blue"
138
- for word in predictions["word"]
139
  ]
140
- plt.barh(predictions["word"], predictions["prob"], color=colors)
141
- plt.xlabel("Probability")
142
- plt.ylabel("Word")
143
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import random
3
 
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
  import polars as pl
 
8
  import torch
9
  from huggingface_hub import hf_hub_download
 
10
  from PIL import Image
 
11
  from torchvision.transforms import v2
12
 
13
  from model import ScribbleItNet
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Matplotlib configuration
16
  plt.rc('font', size=16)
17
+ plt.rc('axes', titlesize=24)
18
+ plt.rc('xtick', labelsize=20)
19
  plt.rc('ytick', labelsize=20)
20
 
21
 
22
  # Load the model
 
23
  def load_model():
24
  hf_hub_download("ScribbleItAI/efficientnet-b0",
25
  token=os.getenv("HF_TOKEN"),
 
32
 
33
 
34
  model = load_model()
35
+
36
+ # Transform configuration
37
  transform = v2.Compose([
38
  v2.Resize((224, 224)),
39
  v2.ToDtype(torch.float32, scale=True),
 
41
  ])
42
 
43
 
44
+ # Load vocabulary
 
45
  def load_vocabulary():
46
  hf_hub_download("ScribbleItAI/efficientnet-b0",
47
  token=os.getenv("HF_TOKEN"),
 
52
 
53
  vocabulary = load_vocabulary()
54
  idx2vocab = {row["word_idx"]: row for row in vocabulary}
55
+ vocab_list = [row["word"] for row in vocabulary]
56
+
57
+
58
+ # Select a random word
59
+ def get_random_word():
60
+ return random.choice(vocab_list)
61
+
62
+
63
+ # Process the image drawn on canvas
64
+ def process_image(image, current_word):
65
+ img_tensor = torch.tensor(image["composite"]).permute(2, 0, 1)
66
+ img_tensor = transform(img_tensor)
67
+ outputs = model(img_tensor.unsqueeze(0).to(model.device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  outputs = torch.softmax(outputs, dim=1)
69
  preds, indices = outputs.topk(100, dim=1)
70
 
 
79
  "category": vocab["category_name"],
80
  "prob": pred
81
  })
82
+
83
+ predictions_df = pl.DataFrame(predictions)
84
+ predictions_df = predictions_df.with_columns(
85
+ pl.col("word").str.to_lowercase())
86
+ predictions_df = predictions_df.group_by("word").agg(
87
+ pl.col("prob").max().alias("prob")).sort("prob").tail(10)
88
+
89
+ # Visualizing predictions
90
+ fig, ax = plt.subplots(figsize=(10, 8))
91
+ plt.subplots_adjust(left=0.25, top=0.9, right=0.9, bottom=0.1)
92
  colors = [
93
+ "green" if word == current_word else "tab:blue"
94
+ for word in predictions_df["word"]
95
  ]
96
+ ax.barh(predictions_df["word"], predictions_df["prob"], color=colors)
97
+ ax.set_title("Top 10 Predictions", pad=15)
98
+ ax.set_xlabel("Probability")
99
+ plt.close(fig)
100
+ return fig, current_word
101
+
102
+
103
+ def update_image(image):
104
+ image = Image.fromarray(image["composite"])
105
+ return image
106
+
107
+
108
+ def create_initial_image():
109
+ data = np.full((500, 700, 3), 255, dtype=np.uint8) # White image
110
+ return Image.fromarray(data)
111
+
112
+
113
+ # Factory for the ImageEditor's initial white canvas (passed uncalled as the component's `value`)
114
+ initial_image = create_initial_image
115
+
116
+ # Interface definition
117
+ with gr.Blocks(theme=gr.themes.Soft(),
118
+ css="input {font-size: 24px; font-weight: 600;}") as demo_app:
119
+ gr.Markdown("# Scribble It! AI Demo 🎨")
120
+ gr.Markdown("### Draw the word shown and let the AI guess what it is!")
121
+
122
+ with gr.Row():
123
+ word_output = gr.Textbox(label="Your word to draw:",
124
+ value=get_random_word(),
125
+ scale=1,
126
+ max_lines=1)
127
+ new_word_button = gr.Button("New Word", scale=0, variant="primary")
128
+
129
+ with gr.Row():
130
+ image_editor = gr.ImageEditor(
131
+ label="Draw Here!",
132
+ image_mode="RGB",
133
+ sources=[],
134
+ transforms=[],
135
+ layers=False,
136
+ value=initial_image,
137
+ brush=gr.Brush(colors=["#000000", "#FF0000", "#00FF00", "#0000FF"],
138
+ default_size=10))
139
+ plot_output = gr.Plot(label="Model Guesses")
140
+
141
+ image_editor.clear(create_initial_image, outputs=image_editor)
142
+ image_editor.change(process_image,
143
+ inputs=[image_editor, word_output],
144
+ outputs=[plot_output, word_output])
145
+
146
+ new_word_button.click(get_random_word, outputs=word_output)
147
+ new_word_button.click(create_initial_image, outputs=image_editor)
148
+
149
+ demo_app.launch()
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
- streamlit==1.35.0; python_version >= '3.8' and python_full_version != '3.9.7'
2
- streamlit-drawable-canvas==0.9.3; python_version >= '3.6'
3
  huggingface-hub==0.23.4; python_version >= '3.8'
4
  polars==0.20.31; python_version >= '3.8'
5
  matplotlib==3.9.0; python_version >= '3.9'
 
1
+ gradio==4.36.1; python_version >= '3.8'
 
2
  huggingface-hub==0.23.4; python_version >= '3.8'
3
  polars==0.20.31; python_version >= '3.8'
4
  matplotlib==3.9.0; python_version >= '3.9'