Philipp Normann commited on
Commit
970bed6
·
1 Parent(s): 90918bb

Argsort operation is now integrated into the ONNX model

Browse files
Files changed (1) hide show
  1. app.py +6 -15
app.py CHANGED
@@ -4,11 +4,11 @@ import random
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
7
  import polars as pl
8
  import seaborn as sns
9
  from huggingface_hub import hf_hub_download
10
  from PIL import Image
11
- import onnxruntime as ort
12
 
13
  # Seaborn configuration
14
  sns.set_theme()
@@ -43,12 +43,12 @@ def compute_word_weights(vocabulary):
43
  total_train_count = train_counts["train_count"].sum()
44
  word_weights = [(vocab["word"], vocab["train_count"] / total_train_count)
45
  for vocab in vocabulary.rows(named=True)]
46
- words = [word for word, _ in word_weights]
47
- weights = [weight for _, weight in word_weights]
48
- return words, weights
49
 
50
 
51
  ort_session = load_model()
 
 
52
  vocabulary = load_vocabulary()
53
  words, weights = compute_word_weights(vocabulary)
54
 
@@ -60,13 +60,10 @@ def get_random_word():
60
  # Process the image drawn on canvas
61
  def process_image(image, current_word):
62
  input_img = image["composite"].resize((224, 224))
63
- inputs = {ort_session.get_inputs()[0].name: np.array(input_img)}
64
- outputs = ort_session.run(None, inputs)[0]
65
- indices_i = np.argsort(outputs)[::-1]
66
- preds_i = outputs[indices_i]
67
 
68
  predictions = []
69
- for pred, idx in zip(preds_i, indices_i):
70
  vocab = vocabulary.row(idx, named=True)
71
  predictions.append({"word": vocab["word"], "category": vocab["category_name"], "prob": pred})
72
 
@@ -98,15 +95,9 @@ def process_image(image, current_word):
98
  ax.set_title("Top 10 Predictions", pad=15)
99
  ax.set_xlabel("Probability")
100
  ax.set_ylabel(None)
101
- plt.close(fig)
102
  return fig, current_word
103
 
104
 
105
- def update_image(image):
106
- image = Image.fromarray(image["composite"])
107
- return image
108
-
109
-
110
  def create_initial_image():
111
  data = np.full((520, 700, 3), 255, dtype=np.uint8) # White image
112
  return Image.fromarray(data)
 
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
+ import onnxruntime as ort
8
  import polars as pl
9
  import seaborn as sns
10
  from huggingface_hub import hf_hub_download
11
  from PIL import Image
 
12
 
13
  # Seaborn configuration
14
  sns.set_theme()
 
43
  total_train_count = train_counts["train_count"].sum()
44
  word_weights = [(vocab["word"], vocab["train_count"] / total_train_count)
45
  for vocab in vocabulary.rows(named=True)]
46
+ return [word for word, _ in word_weights], [weight for _, weight in word_weights]
 
 
47
 
48
 
49
  ort_session = load_model()
50
+ input_name = ort_session.get_inputs()[0].name
51
+
52
  vocabulary = load_vocabulary()
53
  words, weights = compute_word_weights(vocabulary)
54
 
 
60
  # Process the image drawn on canvas
61
  def process_image(image, current_word):
62
  input_img = image["composite"].resize((224, 224))
63
+ indices, preds = ort_session.run(None, {input_name: np.array(input_img)})
 
 
 
64
 
65
  predictions = []
66
+ for pred, idx in zip(preds, indices):
67
  vocab = vocabulary.row(idx, named=True)
68
  predictions.append({"word": vocab["word"], "category": vocab["category_name"], "prob": pred})
69
 
 
95
  ax.set_title("Top 10 Predictions", pad=15)
96
  ax.set_xlabel("Probability")
97
  ax.set_ylabel(None)
 
98
  return fig, current_word
99
 
100
 
 
 
 
 
 
101
  def create_initial_image():
102
  data = np.full((520, 700, 3), 255, dtype=np.uint8) # White image
103
  return Image.fromarray(data)