viewembedding / app.py
rianders's picture
Update app.py
edbfc19
raw
history blame
1.71 kB
import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go
# Interactive 3D plot of BERT embeddings
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection in which ``embeddings[i]`` supplies
            the x, y and z coordinates (positions 0, 1 and 2) of
            ``words[i]``. Presumably PCA-reduced BERT vectors, judging by
            the axis titles -- confirm against the caller.
        words: Labels to plot. Each label gets its own trace, named after
            the label and drawn as a marker with the label as text.

    Returns:
        A ``go.Figure`` sized 800x600, or ``None`` when fewer than four
        labels are supplied (the problem is reported via ``st.error``).
    """
    if len(words) < 4:
        st.error("Please provide at least 4 words/phrases for effective visualization.")
        return None

    def build_trace(position, label):
        # One single-point trace per label so every label is a named series.
        x_value, y_value, z_value = (
            [embeddings[position][axis]] for axis in range(3)
        )
        return go.Scatter3d(
            x=x_value,
            y=y_value,
            z=z_value,
            mode='markers+text',
            text=[label],
            name=label,
        )

    traces = [build_trace(position, label) for position, label in enumerate(words)]

    scene_axes = {
        'xaxis': {'title': 'PCA Component 1'},
        'yaxis': {'title': 'PCA Component 2'},
        'zaxis': {'title': 'PCA Component 3'},
    }
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene_axes,
        autosize=False,  # fixed canvas size instead of automatic sizing
        width=800,
        height=600,
    )
    return go.Figure(data=traces, layout=figure_layout)
# Streamlit app
def main():
    """Render the Streamlit page: collect labels, embed them and plot them.

    The page shows a title, a text area pre-filled with sample labels and a
    "Generate Embeddings" button. When the button is pressed, the
    comma-separated labels are passed to ``get_bert_embeddings`` and the
    result is handed to ``plot_interactive_bert_embeddings``; the figure is
    displayed only if one was actually produced.
    """
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area(
        "Enter words/phrases separated by commas:",
        "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake",
    )
    # Drop blank entries (e.g. from a trailing or doubled comma) so they are
    # neither embedded nor counted towards the minimum number of labels.
    stripped_parts = (part.strip() for part in words_input.split(','))
    words = [word for word in stripped_parts if word]

    if st.button("Generate Embeddings"):
        with st.spinner('Generating embeddings...'):
            # NOTE(review): get_bert_embeddings is not defined in the visible
            # part of this file -- confirm it is provided elsewhere.
            embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            # The plotting helper returns None (after showing its own error
            # message) when there are too few labels; do not try to render
            # a chart in that case.
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
# Start the app only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()