rianders committed on
Commit
f9982c8
·
verified ·
1 Parent(s): 0a0a961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -42
app.py CHANGED
@@ -50,19 +50,55 @@ def plot_interactive_embeddings(embeddings, phrases):
50
  st.error("Please add at least one more phrase to visualize.")
51
 
52
  def main():
 
53
  st.title("Language Model Embeddings Visualization")
54
 
55
  st.markdown("""
56
  This application visualizes embeddings of words and phrases from BERT or RoBERTa language models.
57
- Here's how to use it:
58
- 1. Choose a model (BERT or RoBERTa) from the dropdown menu.
59
- 2. Enter words or phrases one at a time, or upload a CSV file with a 'phrase' column.
60
- 3. View the 2D or 3D plot of the embeddings.
61
- 4. Download the current database as a CSV file for later use.
62
  Explore how different words and phrases relate to each other in the embedding space!
63
  """)
64
 
65
- model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  tokenizer, model = load_model(model_choice)
67
 
68
  default_phrase = "example"
@@ -72,47 +108,15 @@ def main():
72
  embedding = get_embeddings([default_phrase], tokenizer, model)[0]
73
  save_embeddings_to_db(default_phrase, embedding)
74
 
75
- st.write(f"Current phrases ({model_choice}):", ", ".join(st.session_state.phrases))
76
-
77
- new_phrase = st.text_input("Enter a new word or phrase:", "")
78
- if st.button("Add Phrase"):
79
- if new_phrase and new_phrase not in st.session_state.phrases:
80
- embedding = get_embeddings([new_phrase], tokenizer, model)[0]
81
- save_embeddings_to_db(new_phrase, embedding)
82
- st.session_state.phrases.append(new_phrase)
83
- st.experimental_rerun()
84
-
85
- uploaded_file = st.file_uploader("Upload CSV file", type="csv")
86
- if uploaded_file is not None:
87
- df = pd.read_csv(uploaded_file)
88
- phrase_column = next((col for col in ['phrase', 'Phrase'] if col in df.columns), None)
89
- if phrase_column:
90
- new_phrases = df[phrase_column].dropna().unique().tolist()
91
- for phrase in new_phrases:
92
- if phrase and phrase not in st.session_state.phrases:
93
- embedding = get_embeddings([phrase], tokenizer, model)[0]
94
- save_embeddings_to_db(phrase, embedding)
95
- st.session_state.phrases.append(phrase)
96
- st.success(f"Successfully imported {len(new_phrases)} new phrases.")
97
- st.experimental_rerun()
98
- else:
99
- st.error("The CSV file must contain a 'phrase' or 'Phrase' column.")
100
-
101
- if st.button("Clear All Entries"):
102
- clear_all_entries()
103
- st.session_state.phrases = [default_phrase]
104
- embedding = get_embeddings([default_phrase], tokenizer, model)[0]
105
- save_embeddings_to_db(default_phrase, embedding)
106
- st.experimental_rerun()
107
-
108
- if st.button("Download Database as CSV"):
109
- csv = fetch_data_as_csv()
110
- st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
111
 
112
  embeddings, phrases = get_all_embeddings()
113
  if len(embeddings) > 0:
114
  embeddings = np.array(embeddings)
115
  plot_interactive_embeddings(embeddings, phrases)
 
 
116
 
117
  if __name__ == "__main__":
118
  main()
 
50
  st.error("Please add at least one more phrase to visualize.")
51
 
52
  def main():
53
+ st.set_page_config(layout="wide")
54
  st.title("Language Model Embeddings Visualization")
55
 
56
  st.markdown("""
57
  This application visualizes embeddings of words and phrases from BERT or RoBERTa language models.
 
 
 
 
 
58
  Explore how different words and phrases relate to each other in the embedding space!
59
  """)
60
 
61
+ # Sidebar
62
+ with st.sidebar:
63
+ st.header("Controls")
64
+ model_choice = st.selectbox("Choose a model:", ["BERT", "RoBERTa"])
65
+
66
+ new_phrase = st.text_input("Enter a new word or phrase:", "")
67
+ if st.button("Add Phrase"):
68
+ if new_phrase and new_phrase not in st.session_state.phrases:
69
+ embedding = get_embeddings([new_phrase], tokenizer, model)[0]
70
+ save_embeddings_to_db(new_phrase, embedding)
71
+ st.session_state.phrases.append(new_phrase)
72
+ st.experimental_rerun()
73
+
74
+ uploaded_file = st.file_uploader("Upload CSV file", type="csv")
75
+ if uploaded_file is not None:
76
+ df = pd.read_csv(uploaded_file)
77
+ phrase_column = next((col for col in ['phrase', 'Phrase'] if col in df.columns), None)
78
+ if phrase_column:
79
+ new_phrases = df[phrase_column].dropna().unique().tolist()
80
+ for phrase in new_phrases:
81
+ if phrase and phrase not in st.session_state.phrases:
82
+ embedding = get_embeddings([phrase], tokenizer, model)[0]
83
+ save_embeddings_to_db(phrase, embedding)
84
+ st.session_state.phrases.append(phrase)
85
+ st.success(f"Successfully imported {len(new_phrases)} new phrases.")
86
+ st.experimental_rerun()
87
+ else:
88
+ st.error("The CSV file must contain a 'phrase' or 'Phrase' column.")
89
+
90
+ if st.button("Clear All Entries"):
91
+ clear_all_entries()
92
+ st.session_state.phrases = [default_phrase]
93
+ embedding = get_embeddings([default_phrase], tokenizer, model)[0]
94
+ save_embeddings_to_db(default_phrase, embedding)
95
+ st.experimental_rerun()
96
+
97
+ if st.button("Download Database as CSV"):
98
+ csv = fetch_data_as_csv()
99
+ st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
100
+
101
+ # Main area
102
  tokenizer, model = load_model(model_choice)
103
 
104
  default_phrase = "example"
 
108
  embedding = get_embeddings([default_phrase], tokenizer, model)[0]
109
  save_embeddings_to_db(default_phrase, embedding)
110
 
111
+ st.subheader(f"Current phrases ({model_choice}):")
112
+ st.write(", ".join(st.session_state.phrases))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  embeddings, phrases = get_all_embeddings()
115
  if len(embeddings) > 0:
116
  embeddings = np.array(embeddings)
117
  plot_interactive_embeddings(embeddings, phrases)
118
+ else:
119
+ st.info("Add phrases using the sidebar to visualize their embeddings.")
120
 
121
  if __name__ == "__main__":
122
  main()