rianders commited on
Commit
78f2519
·
verified ·
1 Parent(s): e2d7fb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -50
app.py CHANGED
@@ -1,81 +1,109 @@
1
  import streamlit as st
2
- from transformers import BertModel, BertTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
6
  from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv
7
 
8
- # Initialize BERT model and tokenizer
9
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
- model = BertModel.from_pretrained('bert-base-uncased')
 
 
 
 
 
 
11
 
12
- def get_bert_embeddings(words):
13
  embeddings = []
14
  for word in words:
15
- inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
  mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
18
- embeddings.append(mean_embedding[0]) # Append the 1D embedding
19
  return np.array(embeddings)
20
 
21
- def plot_interactive_bert_embeddings(embeddings, words):
22
- if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
23
- pca = PCA(n_components=3)
24
  reduced_embeddings = pca.fit_transform(embeddings)
25
- fig = go.Figure(data=[
26
- go.Scatter3d(
27
- x=[emb[0]],
28
- y=[emb[1]],
29
- z=[emb[2]],
30
- mode='markers+text',
31
- text=[word],
32
- name=word
33
- ) for emb, word in zip(reduced_embeddings, words)
34
- ], layout=go.Layout(
35
- title='3D Scatter Plot of BERT Embeddings',
36
- scene=dict(
37
- xaxis=dict(title='PCA Component 1'),
38
- yaxis=dict(title='PCA Component 2'),
39
- zaxis=dict(title='PCA Component 3')
40
- ),
41
- autosize=False,
42
- width=800,
43
- height=600
44
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  st.plotly_chart(fig, use_container_width=True)
46
  else:
47
- st.error("Please add more words to visualize. A minimum of three is required.")
48
 
49
  def main():
50
- st.title("BERT Embeddings Visualization")
51
 
52
- default_words = ["apple", "rocket", "philosophy"]
53
- if "words" not in st.session_state:
54
- st.session_state.words = default_words[:]
55
- init_db() # Initialize the database
56
- for word in default_words:
57
- embedding = get_bert_embeddings([word])[0]
58
- save_embeddings_to_db(word, embedding)
59
 
60
- if st.button("Reset to Default Words"):
 
 
 
 
 
 
 
 
 
61
  clear_all_entries()
62
- st.session_state.words = default_words[:]
63
- for word in default_words:
64
- embedding = get_bert_embeddings([word])[0]
65
- save_embeddings_to_db(word, embedding)
66
- st.experimental_rerun()
67
 
68
- new_word = st.text_input("Enter a new word or phrase:")
69
  if st.button("Add Word/Phrase"):
70
  if new_word:
71
- embedding = get_bert_embeddings([new_word])[0]
72
  save_embeddings_to_db(new_word, embedding)
73
  st.session_state.words.append(new_word)
74
  st.experimental_rerun()
75
 
76
  if st.button("Clear All Entries"):
77
  clear_all_entries()
78
- st.session_state.words = default_words[:]
 
 
79
  st.experimental_rerun()
80
 
81
  if st.button("Download Database as CSV"):
@@ -85,7 +113,7 @@ def main():
85
  embeddings, words = get_all_embeddings()
86
  if len(embeddings) > 0:
87
  embeddings = np.array(embeddings)
88
- plot_interactive_bert_embeddings(embeddings, words)
89
 
90
  if __name__ == "__main__":
91
- main()
 
1
  import streamlit as st
2
+ from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
3
  from sklearn.decomposition import PCA
4
  import plotly.graph_objs as go
5
  import numpy as np
6
  from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, clear_all_entries, fetch_data_as_csv
7
 
8
+ @st.cache_resource
9
+ def load_model(model_name):
10
+ if model_name == "BERT":
11
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
+ model = BertModel.from_pretrained('bert-base-uncased')
13
+ elif model_name == "RoBERTa":
14
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
15
+ model = RobertaModel.from_pretrained('roberta-base')
16
+ return tokenizer, model
17
 
18
+ def get_embeddings(words, tokenizer, model):
19
  embeddings = []
20
  for word in words:
21
+ inputs = tokenizer(word, return_tensors='pt', padding=True, truncation=True)
22
  outputs = model(**inputs)
23
  mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
24
+ embeddings.append(mean_embedding[0])
25
  return np.array(embeddings)
26
 
27
+ def plot_interactive_embeddings(embeddings, words):
28
+ if len(words) >= 2:
29
+ pca = PCA(n_components=min(3, len(words)))
30
  reduced_embeddings = pca.fit_transform(embeddings)
31
+
32
+ if len(words) == 2:
33
+ fig = go.Figure(data=[
34
+ go.Scatter(
35
+ x=[emb[0]],
36
+ y=[emb[1]],
37
+ mode='markers+text',
38
+ text=[word],
39
+ name=word
40
+ ) for emb, word in zip(reduced_embeddings, words)
41
+ ])
42
+ fig.update_layout(
43
+ title='2D Scatter Plot of Embeddings',
44
+ xaxis_title='PCA Component 1',
45
+ yaxis_title='PCA Component 2'
46
+ )
47
+ else:
48
+ fig = go.Figure(data=[
49
+ go.Scatter3d(
50
+ x=[emb[0]],
51
+ y=[emb[1]],
52
+ z=[emb[2]],
53
+ mode='markers+text',
54
+ text=[word],
55
+ name=word
56
+ ) for emb, word in zip(reduced_embeddings, words)
57
+ ])
58
+ fig.update_layout(
59
+ title='3D Scatter Plot of Embeddings',
60
+ scene=dict(
61
+ xaxis_title='PCA Component 1',
62
+ yaxis_title='PCA Component 2',
63
+ zaxis_title='PCA Component 3'
64
+ )
65
+ )
66
+
67
+ fig.update_layout(autosize=False, width=800, height=600)
68
  st.plotly_chart(fig, use_container_width=True)
69
  else:
70
+ st.error("Please add at least one more word to visualize.")
71
 
72
  def main():
73
+ st.title("Language Model Embeddings Visualization")
74
 
75
+ model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
76
+ tokenizer, model = load_model(model_choice)
 
 
 
 
 
77
 
78
+ default_word = "example"
79
+ if "words" not in st.session_state or "model" not in st.session_state:
80
+ st.session_state.words = [default_word]
81
+ st.session_state.model = model_choice
82
+ init_db()
83
+ embedding = get_embeddings([default_word], tokenizer, model)[0]
84
+ save_embeddings_to_db(default_word, embedding)
85
+ elif st.session_state.model != model_choice:
86
+ st.session_state.words = [default_word]
87
+ st.session_state.model = model_choice
88
  clear_all_entries()
89
+ embedding = get_embeddings([default_word], tokenizer, model)[0]
90
+ save_embeddings_to_db(default_word, embedding)
91
+
92
+ st.write(f"Current words ({model_choice}):", ", ".join(st.session_state.words))
 
93
 
94
+ new_word = st.text_input("Enter a new word or phrase:", "")
95
  if st.button("Add Word/Phrase"):
96
  if new_word:
97
+ embedding = get_embeddings([new_word], tokenizer, model)[0]
98
  save_embeddings_to_db(new_word, embedding)
99
  st.session_state.words.append(new_word)
100
  st.experimental_rerun()
101
 
102
  if st.button("Clear All Entries"):
103
  clear_all_entries()
104
+ st.session_state.words = [default_word]
105
+ embedding = get_embeddings([default_word], tokenizer, model)[0]
106
+ save_embeddings_to_db(default_word, embedding)
107
  st.experimental_rerun()
108
 
109
  if st.button("Download Database as CSV"):
 
113
  embeddings, words = get_all_embeddings()
114
  if len(embeddings) > 0:
115
  embeddings = np.array(embeddings)
116
+ plot_interactive_embeddings(embeddings, words)
117
 
118
  if __name__ == "__main__":
119
+ main()