rianders commited on
Commit
6ef9d55
·
verified ·
1 Parent(s): 3e8771c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -3,8 +3,9 @@ from transformers import AutoModel, BertTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
 
6
 
7
- # Initialize tokenizer and model globally if possible to avoid reloading it on every function call
8
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
9
  model = AutoModel.from_pretrained('bert-base-uncased')
10
 
@@ -13,8 +14,8 @@ def get_bert_embeddings(words):
13
  for word in words:
14
  inputs = tokenizer(word, return_tensors='pt')
15
  outputs = model(**inputs)
16
- mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
17
- embeddings.append(mean_embedding)
18
  return embeddings
19
 
20
  def plot_interactive_bert_embeddings(embeddings, words):
@@ -48,23 +49,40 @@ def plot_interactive_bert_embeddings(embeddings, words):
48
  def main():
49
  st.title("BERT Embeddings Visualization")
50
 
51
- default_words = ["apple", "rocket", "philosophy"] # Default set of words
52
- state_key = "words"
53
- if state_key not in st.session_state:
54
- st.session_state[state_key] = default_words
 
 
 
55
 
56
  if st.button("Reset to Default Words"):
57
- st.session_state[state_key] = default_words[:]
 
 
 
 
58
  st.experimental_rerun()
59
 
60
  new_word = st.text_input("Enter a new word or phrase:")
61
  if st.button("Add Word/Phrase"):
62
  if new_word:
63
- st.session_state[state_key].append(new_word)
 
 
64
  st.experimental_rerun()
65
 
66
- words = st.session_state[state_key]
67
- embeddings = get_bert_embeddings(words)
 
 
 
 
 
 
 
 
68
  plot_interactive_bert_embeddings(embeddings, words)
69
 
70
  if __name__ == "__main__":
 
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
6
+ from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv
7
 
8
+ # Initialize BERT model and tokenizer
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
  model = AutoModel.from_pretrained('bert-base-uncased')
11
 
 
14
  for word in words:
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
+ mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
18
+ embeddings.append(mean_embedding[0])
19
  return embeddings
20
 
21
  def plot_interactive_bert_embeddings(embeddings, words):
 
49
  def main():
50
  st.title("BERT Embeddings Visualization")
51
 
52
+ default_words = ["apple", "rocket", "philosophy"]
53
+ if "words" not in st.session_state:
54
+ st.session_state.words = default_words
55
+ init_db() # Initialize the database
56
+ for word in default_words:
57
+ embedding = get_bert_embeddings([word])[0]
58
+ save_embeddings_to_db(word, embedding)
59
 
60
  if st.button("Reset to Default Words"):
61
+ clear_all_entries()
62
+ st.session_state.words = default_words[:]
63
+ for word in default_words:
64
+ embedding = get_bert_embeddings([word])[0]
65
+ save_embeddings_to_db(word, embedding)
66
  st.experimental_rerun()
67
 
68
  new_word = st.text_input("Enter a new word or phrase:")
69
  if st.button("Add Word/Phrase"):
70
  if new_word:
71
+ embedding = get_bert_embeddings([new_word])[0]
72
+ save_embeddings_to_db(new_word, embedding)
73
+ st.session_state.words.append(new_word)
74
  st.experimental_rerun()
75
 
76
+ if st.button("Clear All Entries"):
77
+ clear_all_entries()
78
+ st.session_state.words = default_words[:]
79
+ st.experimental_rerun()
80
+
81
+ if st.button("Download Database as CSV"):
82
+ csv = fetch_data_as_csv()
83
+ st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
84
+
85
+ embeddings, words = get_all_embeddings()
86
  plot_interactive_bert_embeddings(embeddings, words)
87
 
88
  if __name__ == "__main__":