rianders committed on
Commit
3e8771c
·
verified ·
1 Parent(s): 80f1785

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -46
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import streamlit as st
2
- from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries
3
- from transformers import BertModel, BertTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
  import numpy as np
7
 
8
- # Initialize and load the BERT model and tokenizer
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
- model = BertModel.from_pretrained('bert-base-uncased')
11
 
12
  def get_bert_embeddings(words):
13
  embeddings = []
@@ -19,57 +18,54 @@ def get_bert_embeddings(words):
19
  return embeddings
20
 
21
  def plot_interactive_bert_embeddings(embeddings, words):
22
- pca = PCA(n_components=3)
23
- reduced_embeddings = pca.fit_transform(embeddings)
24
- fig = go.Figure(data=[
25
- go.Scatter3d(
26
- x=[emb[0]],
27
- y=[emb[1]],
28
- z=[emb[2]],
29
- mode='markers+text',
30
- text=word,
31
- name=word
32
- ) for emb, word in zip(reduced_embeddings, words)
33
- ], layout=go.Layout(
34
- title='3D Scatter Plot of BERT Embeddings',
35
- scene=dict(
36
- xaxis=dict(title='PCA Component 1'),
37
- yaxis=dict(title='PCA Component 2'),
38
- zaxis=dict(title='PCA Component 3')
39
- ),
40
- autosize=False,
41
- width=800,
42
- height=600
43
- ))
44
- st.plotly_chart(fig, use_container_width=True)
 
 
 
45
 
46
  def main():
47
  st.title("BERT Embeddings Visualization")
48
- init_db()
49
 
50
- # Default starter words
51
- default_words = ["apple", "rocket", "philosophy"]
52
- # Load and plot default words if database is empty
53
- if not get_all_embeddings():
54
- embeddings = get_bert_embeddings(default_words)
55
- for word, emb in zip(default_words, embeddings):
56
- save_embeddings_to_db(word, emb)
57
- plot_interactive_bert_embeddings(embeddings, default_words)
58
 
59
  new_word = st.text_input("Enter a new word or phrase:")
60
  if st.button("Add Word/Phrase"):
61
  if new_word:
62
- embedding = get_bert_embeddings([new_word])[0]
63
- save_embeddings_to_db(new_word, embedding)
64
- embeddings, words = get_all_embeddings()
65
- plot_interactive_bert_embeddings(embeddings, words)
66
 
67
- if st.button("Reset to Default Words"):
68
- clear_all_entries()
69
- embeddings = get_bert_embeddings(default_words)
70
- for word, emb in zip(default_words, embeddings):
71
- save_embeddings_to_db(word, emb)
72
- plot_interactive_bert_embeddings(embeddings, default_words)
73
 
74
  if __name__ == "__main__":
75
  main()
 
1
  import streamlit as st
2
+ from transformers import AutoModel, BertTokenizer
 
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
6
 
7
+ # Initialize the tokenizer and model once at module level to avoid reloading them on every function call
8
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
9
+ model = AutoModel.from_pretrained('bert-base-uncased')
10
 
11
  def get_bert_embeddings(words):
12
  embeddings = []
 
18
  return embeddings
19
 
20
  def plot_interactive_bert_embeddings(embeddings, words):
21
+ if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
22
+ pca = PCA(n_components=3)
23
+ reduced_embeddings = pca.fit_transform(embeddings)
24
+ fig = go.Figure(data=[
25
+ go.Scatter3d(
26
+ x=[emb[0]],
27
+ y=[emb[1]],
28
+ z=[emb[2]],
29
+ mode='markers+text',
30
+ text=word,
31
+ name=word
32
+ ) for emb, word in zip(reduced_embeddings, words)
33
+ ], layout=go.Layout(
34
+ title='3D Scatter Plot of BERT Embeddings',
35
+ scene=dict(
36
+ xaxis=dict(title='PCA Component 1'),
37
+ yaxis=dict(title='PCA Component 2'),
38
+ zaxis=dict(title='PCA Component 3')
39
+ ),
40
+ autosize=False,
41
+ width=800,
42
+ height=600
43
+ ))
44
+ st.plotly_chart(fig, use_container_width=True)
45
+ else:
46
+ st.error("Please add more words to visualize. A minimum of three is required.")
47
 
48
  def main():
49
  st.title("BERT Embeddings Visualization")
 
50
 
51
+ default_words = ["apple", "rocket", "philosophy"] # Default set of words
52
+ state_key = "words"
53
+ if state_key not in st.session_state:
54
+ st.session_state[state_key] = default_words
55
+
56
+ if st.button("Reset to Default Words"):
57
+ st.session_state[state_key] = default_words[:]
58
+ st.experimental_rerun()
59
 
60
  new_word = st.text_input("Enter a new word or phrase:")
61
  if st.button("Add Word/Phrase"):
62
  if new_word:
63
+ st.session_state[state_key].append(new_word)
64
+ st.experimental_rerun()
 
 
65
 
66
+ words = st.session_state[state_key]
67
+ embeddings = get_bert_embeddings(words)
68
+ plot_interactive_bert_embeddings(embeddings, words)
 
 
 
69
 
70
  if __name__ == "__main__":
71
  main()