rianders commited on
Commit
05c46d6
·
verified ·
1 Parent(s): 7be9f60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -22
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from database_utils import init_db, save_embeddings_to_db, get_all_embeddings
3
  from transformers import BertModel, BertTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
@@ -13,7 +13,6 @@ def get_bert_embeddings(words):
13
  for word in words:
14
  inputs = tokenizer(word, return_tensors='pt')
15
  outputs = model(**inputs)
16
- # Average the embeddings of all tokens for the word/phrase
17
  mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
18
  embeddings.append(mean_embedding)
19
  if len(embeddings) > 0:
@@ -22,7 +21,6 @@ def get_bert_embeddings(words):
22
  return reduced_embeddings
23
  return []
24
 
25
- # Plotly plotting function
26
  def plot_interactive_bert_embeddings(embeddings, words):
27
  if len(words) < 4:
28
  st.error("Please provide at least 4 words/phrases for effective visualization.")
@@ -52,29 +50,36 @@ def plot_interactive_bert_embeddings(embeddings, words):
52
  fig = go.Figure(data=data, layout=layout)
53
  return fig
54
 
55
- def main():
56
- st.title("BERT Embeddings Visualization - Community Edition")
57
-
58
- # Button to initialize the database
59
- if st.button("Initialize Database"):
60
- msg = init_db()
61
- st.success(msg)
62
-
63
- # Text input for new sentence
64
- new_sentence = st.text_input("Enter a new sentence:")
65
- if st.button("Add and Visualize Sentence"):
66
- if new_sentence:
67
- embeddings = get_bert_embeddings([new_sentence])
68
- if embeddings.size > 0:
69
- save_embeddings_to_db(new_sentence, embeddings[0])
70
- st.success("Sentence added and embedding saved!")
71
-
72
- # Button to display all embeddings
73
- if st.button("Show All Embeddings"):
74
  embeddings, sentences = get_all_embeddings()
75
  fig = plot_interactive_bert_embeddings(np.vstack(embeddings), sentences)
76
  if fig is not None:
77
  st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  if __name__ == "__main__":
80
  main()
 
1
  import streamlit as st
2
+ from database_utils import init_db, save_embeddings_to_db, get_all_embeddings, fetch_data_as_csv
3
  from transformers import BertModel, BertTokenizer
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
 
13
  for word in words:
14
  inputs = tokenizer(word, return_tensors='pt')
15
  outputs = model(**inputs)
 
16
  mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
17
  embeddings.append(mean_embedding)
18
  if len(embeddings) > 0:
 
21
  return reduced_embeddings
22
  return []
23
 
 
24
  def plot_interactive_bert_embeddings(embeddings, words):
25
  if len(words) < 4:
26
  st.error("Please provide at least 4 words/phrases for effective visualization.")
 
50
  fig = go.Figure(data=data, layout=layout)
51
  return fig
52
 
53
+ def process_and_visualize_words(words):
54
+ embeddings = get_bert_embeddings(words)
55
+ if embeddings.size > 0:
56
+ for word, embedding in zip(words, embeddings):
57
+ save_embeddings_to_db(word, embedding)
58
+ st.success("Sentences added and embeddings saved!")
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  embeddings, sentences = get_all_embeddings()
60
  fig = plot_interactive_bert_embeddings(np.vstack(embeddings), sentences)
61
  if fig is not None:
62
  st.plotly_chart(fig, use_container_width=True)
63
+ else:
64
+ st.error("Failed to generate embeddings. Ensure your sentences are correctly formatted.")
65
+
66
+ def main():
67
+ st.title("BERT Embeddings Visualization - Community Edition")
68
+
69
+ init_db_message = init_db()
70
+ st.text(init_db_message)
71
+
72
+ new_sentences = st.text_input("Enter at least three words/phrases, comma-separated:", "Example: apple, banana, orange")
73
+ if st.button("Add and Visualize Sentences"):
74
+ words = [word.strip() for word in new_sentences.split(',')]
75
+ if len(words) < 3:
76
+ st.error("Please enter at least three words/phrases separated by commas.")
77
+ else:
78
+ process_and_visualize_words(words)
79
+
80
+ if st.button("Download Database as CSV"):
81
+ csv = fetch_data_as_csv()
82
+ st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
83
 
84
  if __name__ == "__main__":
85
  main()