Philipp Normann commited on
Commit
399e690
·
1 Parent(s): 71cad19

Words in vocab are already lowercase

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -2,10 +2,10 @@ import os
2
  import random
3
 
4
  import gradio as gr
5
- import seaborn as sns
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
  import polars as pl
 
9
  import torch
10
  from huggingface_hub import hf_hub_download
11
  from PIL import Image
@@ -35,16 +35,6 @@ def load_model():
35
  return model
36
 
37
 
38
- model = load_model()
39
-
40
- # Transform configuration
41
- transform = v2.Compose([
42
- v2.Resize((224, 224)),
43
- v2.ToDtype(torch.float32, scale=True),
44
- v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
45
- ])
46
-
47
-
48
  # Load vocabulary
49
  def load_vocabulary():
50
  hf_hub_download("ScribbleItAI/efficientnet-b0",
@@ -64,6 +54,14 @@ def compute_word_weights(vocabulary):
64
  return words, weights
65
 
66
 
 
 
 
 
 
 
 
 
67
  vocabulary = load_vocabulary()
68
  words, weights = compute_word_weights(vocabulary)
69
 
@@ -93,8 +91,6 @@ def process_image(image, current_word):
93
  })
94
 
95
  predictions_df = pl.DataFrame(predictions)
96
- predictions_df = predictions_df.with_columns(
97
- pl.col("word").str.to_lowercase())
98
  predictions_df = predictions_df.group_by("word").agg(
99
  pl.col("prob").max().alias("prob"))
100
  predictions_df = predictions_df.sort("prob", descending=True).head(10)
 
2
  import random
3
 
4
  import gradio as gr
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import polars as pl
8
+ import seaborn as sns
9
  import torch
10
  from huggingface_hub import hf_hub_download
11
  from PIL import Image
 
35
  return model
36
 
37
 
 
 
 
 
 
 
 
 
 
 
38
  # Load vocabulary
39
  def load_vocabulary():
40
  hf_hub_download("ScribbleItAI/efficientnet-b0",
 
54
  return words, weights
55
 
56
 
57
+ model = load_model()
58
+
59
+ transform = v2.Compose([
60
+ v2.Resize((224, 224)),
61
+ v2.ToDtype(torch.float32, scale=True),
62
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
63
+ ])
64
+
65
  vocabulary = load_vocabulary()
66
  words, weights = compute_word_weights(vocabulary)
67
 
 
91
  })
92
 
93
  predictions_df = pl.DataFrame(predictions)
 
 
94
  predictions_df = predictions_df.group_by("word").agg(
95
  pl.col("prob").max().alias("prob"))
96
  predictions_df = predictions_df.sort("prob", descending=True).head(10)