rianders committed on
Commit
77b117f
·
verified ·
1 Parent(s): 26a46c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -41
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
 
2
  from transformers import BertModel, BertTokenizer
3
- import torch
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
  import numpy as np
@@ -10,40 +10,34 @@ def get_bert_embeddings(words):
10
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
  model = BertModel.from_pretrained('bert-base-uncased')
12
  embeddings = []
13
-
14
  for word in words:
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
- # Calculate mean of embeddings across all tokens in the phrase
18
  mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
19
  embeddings.append(mean_embedding)
20
-
21
  if len(embeddings) > 0:
22
  pca = PCA(n_components=3)
23
  reduced_embeddings = pca.fit_transform(np.array(embeddings))
24
  return reduced_embeddings
25
  return []
26
 
27
-
28
-
29
  # Plotly plotting function
30
  def plot_interactive_bert_embeddings(embeddings, words):
31
  if len(words) < 4:
32
  st.error("Please provide at least 4 words/phrases for effective visualization.")
33
  return None
34
-
35
  data = []
36
  for i, word in enumerate(words):
37
  trace = go.Scatter3d(
38
- x=[embeddings[i][0]],
39
- y=[embeddings[i][1]],
40
  z=[embeddings[i][2]],
41
  mode='markers+text',
42
  text=[word],
43
  name=word
44
  )
45
  data.append(trace)
46
-
47
  layout = go.Layout(
48
  title='3D Scatter Plot of BERT Embeddings',
49
  scene=dict(
@@ -55,41 +49,32 @@ def plot_interactive_bert_embeddings(embeddings, words):
55
  width=800,
56
  height=600
57
  )
58
-
59
  fig = go.Figure(data=data, layout=layout)
60
  return fig
61
 
62
  def main():
63
- st.title("BERT Embeddings Visualization")
64
-
65
- # Initialize or get existing words list from the session state
66
- if 'words' not in st.session_state:
67
- st.session_state.words = []
68
-
69
- # Text input for new words
70
- new_words_input = st.text_input("Enter a new word/phrase:")
71
-
72
- # Button to add new words
73
- if st.button("Add Word/Phrase"):
74
- if new_words_input:
75
- st.session_state.words.append(new_words_input)
76
- st.success(f"Added: {new_words_input}")
77
-
78
- # Display current list of words
79
- if st.session_state.words:
80
- st.write("Current list of words/phrases:", ', '.join(st.session_state.words))
81
-
82
- # Generate embeddings and plot
83
- if st.button("Generate Embeddings"):
84
- with st.spinner('Generating embeddings...'):
85
- embeddings = get_bert_embeddings(st.session_state.words)
86
- fig = plot_interactive_bert_embeddings(embeddings, st.session_state.words)
87
- if fig is not None:
88
- st.plotly_chart(fig, use_container_width=True)
89
-
90
- # Reset button
91
- if st.button("Reset"):
92
- st.session_state.words = []
93
 
94
  if __name__ == "__main__":
95
  main()
 
1
  import streamlit as st
2
+ from database_utils import init_db, save_embeddings_to_db, get_all_embeddings
3
  from transformers import BertModel, BertTokenizer
 
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
  import numpy as np
 
10
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
  model = BertModel.from_pretrained('bert-base-uncased')
12
  embeddings = []
 
13
  for word in words:
14
  inputs = tokenizer(word, return_tensors='pt')
15
  outputs = model(**inputs)
16
+ # Average the embeddings of all tokens for the word/phrase
17
  mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
18
  embeddings.append(mean_embedding)
 
19
  if len(embeddings) > 0:
20
  pca = PCA(n_components=3)
21
  reduced_embeddings = pca.fit_transform(np.array(embeddings))
22
  return reduced_embeddings
23
  return []
24
 
 
 
25
  # Plotly plotting function
26
  def plot_interactive_bert_embeddings(embeddings, words):
27
  if len(words) < 4:
28
  st.error("Please provide at least 4 words/phrases for effective visualization.")
29
  return None
 
30
  data = []
31
  for i, word in enumerate(words):
32
  trace = go.Scatter3d(
33
+ x=[embeddings[i][0]],
34
+ y=[embeddings[i][1]],
35
  z=[embeddings[i][2]],
36
  mode='markers+text',
37
  text=[word],
38
  name=word
39
  )
40
  data.append(trace)
 
41
  layout = go.Layout(
42
  title='3D Scatter Plot of BERT Embeddings',
43
  scene=dict(
 
49
  width=800,
50
  height=600
51
  )
 
52
  fig = go.Figure(data=data, layout=layout)
53
  return fig
54
 
def main():
    """Render the Streamlit page for the community embeddings app.

    The page offers three actions, each behind its own button:

    * "Initialize Database" calls ``init_db()`` and shows the message it
      returns.
    * "Add and Visualize Sentence" embeds the sentence typed into the text
      box with ``get_bert_embeddings`` and stores the first returned row
      with ``save_embeddings_to_db``.
    * "Show All Embeddings" loads everything from ``get_all_embeddings()``
      and plots it with ``plot_interactive_bert_embeddings``.

    Returns:
        None. All output goes through Streamlit widgets.
    """
    st.title("BERT Embeddings Visualization - Community Edition")

    # Button to initialize the database
    if st.button("Initialize Database"):
        msg = init_db()
        st.success(msg)

    # Text input for new sentence
    new_sentence = st.text_input("Enter a new sentence:")
    if st.button("Add and Visualize Sentence"):
        # Treat whitespace-only input as empty so blank rows are never saved.
        sentence = new_sentence.strip()
        if not sentence:
            st.warning("Please enter a sentence first.")
        else:
            try:
                embeddings = get_bert_embeddings([sentence])
            except ValueError as exc:
                # NOTE(review): get_bert_embeddings fits PCA(n_components=3)
                # on whatever it is given; scikit-learn refuses that when
                # there are fewer samples than components, which is always
                # the case for a single sentence. Report it instead of
                # letting the traceback take over the page. The real fix
                # belongs in get_bert_embeddings (store the raw vector and
                # reduce at plot time) - TODO confirm with the author.
                st.error(f"Could not compute an embedding: {exc}")
            else:
                # len() works for both the ndarray result and the plain
                # empty-list fallback, unlike the ndarray-only ``.size``.
                if len(embeddings) > 0:
                    save_embeddings_to_db(sentence, embeddings[0])
                    st.success("Sentence added and embedding saved!")

    # Button to display all embeddings
    if st.button("Show All Embeddings"):
        embeddings, sentences = get_all_embeddings()
        if len(embeddings) == 0:
            # np.vstack raises ValueError on an empty sequence, so bail out
            # before stacking when nothing has been stored yet.
            st.info("No embeddings have been saved yet.")
        else:
            fig = plot_interactive_bert_embeddings(np.vstack(embeddings), sentences)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
78
 
79
  if __name__ == "__main__":
80
  main()