rianders commited on
Commit
80f1785
·
verified ·
1 Parent(s): 2a28595

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -51
app.py CHANGED
@@ -1,70 +1,75 @@
1
  import streamlit as st
2
- from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, fetch_data_as_csv, clear_all_entries
3
  from transformers import BertModel, BertTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
  import numpy as np
7
 
8
- # BERT embeddings function
9
- def get_bert_embeddings(word):
10
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
- model = BertModel.from_pretrained('bert-base-uncased')
12
- inputs = tokenizer(word, return_tensors='pt')
13
- outputs = model(**inputs)
14
- mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
15
- return mean_embedding # Return the mean embedding directly
16
-
17
-
18
- def plot_interactive_bert_embeddings():
19
- embeddings, sentences = get_all_embeddings()
20
- if len(sentences) > 0:
21
- # Even if there's less than 3, PCA can still run with min(n_samples, n_features)
22
- pca = PCA(n_components=min(3, len(sentences)))
23
- reduced_embeddings = pca.fit_transform(np.array(embeddings))
24
- fig = go.Figure(data=[
25
- go.Scatter3d(
26
- x=[emb[0]],
27
- y=[emb[1] if len(emb) > 1 else 0], # Ensure there are enough dimensions
28
- z=[emb[2] if len(emb) > 2 else 0],
29
- mode='markers+text',
30
- text=sent,
31
- name=sent
32
- ) for emb, sent in zip(reduced_embeddings, sentences)
33
- ], layout=go.Layout(
34
- title='3D Scatter Plot of BERT Embeddings',
35
- scene=dict(
36
- xaxis=dict(title='PCA Component 1'),
37
- yaxis=dict(title='PCA Component 2'),
38
- zaxis=dict(title='PCA Component 3')
39
- ),
40
- autosize=False,
41
- width=800,
42
- height=600
43
- ))
44
- st.plotly_chart(fig, use_container_width=True)
45
- else:
46
- st.error("No data available for visualization.")
47
 
 
 
 
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def main():
51
- st.title("BERT Embeddings Visualization - Community Edition")
52
-
53
  init_db()
54
-
 
 
 
 
 
 
 
 
 
55
  new_word = st.text_input("Enter a new word or phrase:")
56
  if st.button("Add Word/Phrase"):
57
  if new_word:
58
- embedding = get_bert_embeddings(new_word)
59
  save_embeddings_to_db(new_word, embedding)
60
- st.success(f"Added: {new_word}")
61
-
62
- if st.button("Visualize Embeddings"):
63
- plot_interactive_bert_embeddings()
64
 
65
- if st.button("Clear All Entries"):
66
- clear_all_entries() # This function needs to be defined in database_utils.py to delete all records
67
- st.success("All entries have been cleared.")
 
 
 
68
 
69
  if __name__ == "__main__":
70
  main()
 
1
  import streamlit as st
2
+ from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries
3
  from transformers import BertModel, BertTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
  import numpy as np
7
 
8
+ # Initialize and load the BERT model and tokenizer
9
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
+ model = BertModel.from_pretrained('bert-base-uncased')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def get_bert_embeddings(words):
13
+ embeddings = []
14
+ for word in words:
15
+ inputs = tokenizer(word, return_tensors='pt')
16
+ outputs = model(**inputs)
17
+ mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
18
+ embeddings.append(mean_embedding)
19
+ return embeddings
20
 
21
+ def plot_interactive_bert_embeddings(embeddings, words):
22
+ pca = PCA(n_components=3)
23
+ reduced_embeddings = pca.fit_transform(embeddings)
24
+ fig = go.Figure(data=[
25
+ go.Scatter3d(
26
+ x=[emb[0]],
27
+ y=[emb[1]],
28
+ z=[emb[2]],
29
+ mode='markers+text',
30
+ text=word,
31
+ name=word
32
+ ) for emb, word in zip(reduced_embeddings, words)
33
+ ], layout=go.Layout(
34
+ title='3D Scatter Plot of BERT Embeddings',
35
+ scene=dict(
36
+ xaxis=dict(title='PCA Component 1'),
37
+ yaxis=dict(title='PCA Component 2'),
38
+ zaxis=dict(title='PCA Component 3')
39
+ ),
40
+ autosize=False,
41
+ width=800,
42
+ height=600
43
+ ))
44
+ st.plotly_chart(fig, use_container_width=True)
45
 
46
  def main():
47
+ st.title("BERT Embeddings Visualization")
 
48
  init_db()
49
+
50
+ # Default starter words
51
+ default_words = ["apple", "rocket", "philosophy"]
52
+ # Load and plot default words if database is empty
53
+ if not get_all_embeddings():
54
+ embeddings = get_bert_embeddings(default_words)
55
+ for word, emb in zip(default_words, embeddings):
56
+ save_embeddings_to_db(word, emb)
57
+ plot_interactive_bert_embeddings(embeddings, default_words)
58
+
59
  new_word = st.text_input("Enter a new word or phrase:")
60
  if st.button("Add Word/Phrase"):
61
  if new_word:
62
+ embedding = get_bert_embeddings([new_word])[0]
63
  save_embeddings_to_db(new_word, embedding)
64
+ embeddings, words = get_all_embeddings()
65
+ plot_interactive_bert_embeddings(embeddings, words)
 
 
66
 
67
+ if st.button("Reset to Default Words"):
68
+ clear_all_entries()
69
+ embeddings = get_bert_embeddings(default_words)
70
+ for word, emb in zip(default_words, embeddings):
71
+ save_embeddings_to_db(word, emb)
72
+ plot_interactive_bert_embeddings(embeddings, default_words)
73
 
74
  if __name__ == "__main__":
75
  main()