rianders commited on
Commit
fb57cc2
·
verified ·
1 Parent(s): 6ef9d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -15,8 +15,8 @@ def get_bert_embeddings(words):
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
  mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
18
- embeddings.append(mean_embedding[0])
19
- return embeddings
20
 
21
  def plot_interactive_bert_embeddings(embeddings, words):
22
  if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
@@ -28,7 +28,7 @@ def plot_interactive_bert_embeddings(embeddings, words):
28
  y=[emb[1]],
29
  z=[emb[2]],
30
  mode='markers+text',
31
- text=word,
32
  name=word
33
  ) for emb, word in zip(reduced_embeddings, words)
34
  ], layout=go.Layout(
@@ -51,7 +51,7 @@ def main():
51
 
52
  default_words = ["apple", "rocket", "philosophy"]
53
  if "words" not in st.session_state:
54
- st.session_state.words = default_words
55
  init_db() # Initialize the database
56
  for word in default_words:
57
  embedding = get_bert_embeddings([word])[0]
@@ -83,7 +83,9 @@ def main():
83
  st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
84
 
85
  embeddings, words = get_all_embeddings()
86
- plot_interactive_bert_embeddings(embeddings, words)
 
 
87
 
88
  if __name__ == "__main__":
89
  main()
 
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
  mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
18
+ embeddings.append(mean_embedding[0]) # Append the 1D embedding
19
+ return np.array(embeddings)
20
 
21
  def plot_interactive_bert_embeddings(embeddings, words):
22
  if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
 
28
  y=[emb[1]],
29
  z=[emb[2]],
30
  mode='markers+text',
31
+ text=[word],
32
  name=word
33
  ) for emb, word in zip(reduced_embeddings, words)
34
  ], layout=go.Layout(
 
51
 
52
  default_words = ["apple", "rocket", "philosophy"]
53
  if "words" not in st.session_state:
54
+ st.session_state.words = default_words[:]
55
  init_db() # Initialize the database
56
  for word in default_words:
57
  embedding = get_bert_embeddings([word])[0]
 
83
  st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
84
 
85
  embeddings, words = get_all_embeddings()
86
+ embeddings = np.array(embeddings)
87
+ if embeddings.size > 0:
88
+ plot_interactive_bert_embeddings(embeddings, words)
89
 
90
  if __name__ == "__main__":
91
  main()