Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,8 +15,8 @@ def get_bert_embeddings(words):
|
|
| 15 |
inputs = tokenizer(word, return_tensors='pt')
|
| 16 |
outputs = model(**inputs)
|
| 17 |
mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
|
| 18 |
-
embeddings.append(mean_embedding[0])
|
| 19 |
-
return embeddings
|
| 20 |
|
| 21 |
def plot_interactive_bert_embeddings(embeddings, words):
|
| 22 |
if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
|
|
@@ -28,7 +28,7 @@ def plot_interactive_bert_embeddings(embeddings, words):
|
|
| 28 |
y=[emb[1]],
|
| 29 |
z=[emb[2]],
|
| 30 |
mode='markers+text',
|
| 31 |
-
text=word,
|
| 32 |
name=word
|
| 33 |
) for emb, word in zip(reduced_embeddings, words)
|
| 34 |
], layout=go.Layout(
|
|
@@ -51,7 +51,7 @@ def main():
|
|
| 51 |
|
| 52 |
default_words = ["apple", "rocket", "philosophy"]
|
| 53 |
if "words" not in st.session_state:
|
| 54 |
-
st.session_state.words = default_words
|
| 55 |
init_db() # Initialize the database
|
| 56 |
for word in default_words:
|
| 57 |
embedding = get_bert_embeddings([word])[0]
|
|
@@ -83,7 +83,9 @@ def main():
|
|
| 83 |
st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
|
| 84 |
|
| 85 |
embeddings, words = get_all_embeddings()
|
| 86 |
-
|
|
|
|
|
|
|
| 87 |
|
| 88 |
if __name__ == "__main__":
|
| 89 |
main()
|
|
|
|
| 15 |
inputs = tokenizer(word, return_tensors='pt')
|
| 16 |
outputs = model(**inputs)
|
| 17 |
mean_embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
|
| 18 |
+
embeddings.append(mean_embedding[0]) # Append the 1D embedding
|
| 19 |
+
return np.array(embeddings)
|
| 20 |
|
| 21 |
def plot_interactive_bert_embeddings(embeddings, words):
|
| 22 |
if len(words) >= 3: # Ensure there are at least 3 words for 3D PCA
|
|
|
|
| 28 |
y=[emb[1]],
|
| 29 |
z=[emb[2]],
|
| 30 |
mode='markers+text',
|
| 31 |
+
text=[word],
|
| 32 |
name=word
|
| 33 |
) for emb, word in zip(reduced_embeddings, words)
|
| 34 |
], layout=go.Layout(
|
|
|
|
| 51 |
|
| 52 |
default_words = ["apple", "rocket", "philosophy"]
|
| 53 |
if "words" not in st.session_state:
|
| 54 |
+
st.session_state.words = default_words[:]
|
| 55 |
init_db() # Initialize the database
|
| 56 |
for word in default_words:
|
| 57 |
embedding = get_bert_embeddings([word])[0]
|
|
|
|
| 83 |
st.download_button(label="Download CSV", data=csv, file_name='embeddings.csv', mime='text/csv')
|
| 84 |
|
| 85 |
embeddings, words = get_all_embeddings()
|
| 86 |
+
embeddings = np.array(embeddings)
|
| 87 |
+
if embeddings.size > 0:
|
| 88 |
+
plot_interactive_bert_embeddings(embeddings, words)
|
| 89 |
|
| 90 |
if __name__ == "__main__":
|
| 91 |
main()
|