# berndf's picture
# restored
# 423ac96 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
import torch
import numpy as np
import plotly.graph_objects as go
import math
# Load transformer model + tokenizer.
@st.cache_resource
def load_model():
    """Load the MiniLM sentence-embedding tokenizer and model.

    Both objects come from the same Hugging Face checkpoint. The
    ``st.cache_resource`` decorator lets Streamlit keep the loaded objects
    around instead of re-creating them on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModel.from_pretrained(checkpoint),
    )


# Module-level handles; encode_texts() reads these globals.
tokenizer, model = load_model()
# Encode using mean pooling
def encode_texts(texts):
    """Embed *texts* by mean-pooling the model's last hidden state.

    The texts are tokenized as one padded, truncated batch. Each token
    vector is weighted by the attention mask so padding positions contribute
    nothing, and the weighted sum is divided by the number of real tokens.

    Args:
        texts: The strings to embed, passed straight to the tokenizer.

    Returns:
        The pooled embeddings as a NumPy array (one row per input text,
        assuming the usual ``(batch, tokens, hidden)`` layout of
        ``last_hidden_state``).
    """
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Inference only: no gradients are needed for the forward pass or pooling.
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        # Broadcast the 0/1 mask across the hidden dimension.
        weights = batch["attention_mask"].unsqueeze(-1).expand(hidden.shape).float()
        token_sum = (hidden * weights).sum(dim=1)
        token_count = weights.sum(dim=1)
        mean_vectors = token_sum / token_count
    return mean_vectors.cpu().numpy()
# Session state init
# Seed the text area (bound through key="submitted_text") with demo words,
# one per line, the first time the session runs.
if "submitted_text" not in st.session_state:
    st.session_state["submitted_text"] = (
        "BMW\nPorsche\nMercedes\n"
        "Coffee\nTea\nWater\n"
        "Germany\nItaly\nBrazil\n"
        "Violin\nDrums\nTrumpet\n"
        "Man\nWomen\nChild"
    )
# UI layout
# Two columns with relative widths 1:3 -- input on the left, plot on the right.
col1, col2 = st.columns([1, 3])
with col1:
    st.title("🧠 Embedding Input")
    # The text area lives inside a form, together with its submit button.
    with st.form(key="embedding_input_form"):
        # The label's emoji had been stored as mojibake (UTF-8 bytes of
        # U+2705 decoded with the wrong codec); restored to the real character.
        st.form_submit_button("✅ Submit Text")
        # key="submitted_text" ties the widget value to the session-state
        # entry of the same name, which the rest of the script reads.
        st.text_area(
            label="Enter words (one per line)",
            key="submitted_text",
            height=400,
        )
# One entry per non-blank input line, with surrounding whitespace removed.
texts = [
    word
    for word in map(str.strip, st.session_state.submitted_text.split("\n"))
    if word
]

# The projection below asks PCA for three components, so require at least
# three inputs and halt this script run otherwise.
if len(texts) < 3:
    st.warning("Please enter at least three words.")
    st.stop()

# Embed the words, then reduce the embeddings to three coordinates each.
embeddings = encode_texts(texts)
projector = PCA(n_components=3)
coords = projector.fit_transform(embeddings)
# Rotation frames
# One frame per 2-degree step: the camera eye travels around a circle of
# radius 2 in the x/y plane at a fixed height of z=0.7.
frames = [
    go.Frame(
        layout=dict(
            scene_camera=dict(
                eye=dict(x=2 * math.cos(theta), y=2 * math.sin(theta), z=0.7)
            )
        )
    )
    for theta in (math.radians(degrees) for degrees in range(0, 360, 2))
]
# Plotly figure with animation controls
# 3D scatter of the PCA coordinates, labelled with the input words, plus a
# button that plays the camera-rotation frames.
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers+text",
            text=texts,
            textposition="top center",
            textfont=dict(color="black"),
            marker=dict(size=6),
        )
    ],
    layout=go.Layout(
        title="3D Embedding Projection",
        scene=dict(
            # Each axis gets a translucent background in its own colour
            # (red / green / blue for X / Y / Z).
            xaxis=dict(title="X", showbackground=True, backgroundcolor="rgba(255,0,0,0.4)"),
            yaxis=dict(title="Y", showbackground=True, backgroundcolor="rgba(0,255,0,0.4)"),
            zaxis=dict(title="Z", showbackground=True, backgroundcolor="rgba(0,0,255,0.4)"),
        ),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(
                        # The label's emoji had been stored as mojibake
                        # (mis-decoded UTF-8 of U+1F504); restored here.
                        label="🔄 Rotate",
                        method="animate",
                        args=[
                            # No explicit frame list is given for the animation.
                            None,
                            dict(
                                # 50 ms per frame, no interpolation between frames.
                                frame=dict(duration=50, redraw=True),
                                transition=dict(duration=0),
                                fromcurrent=True,
                                mode="immediate",
                            ),
                        ],
                    )
                ],
                x=0.05,
                y=0.9,
            )
        ],
        margin=dict(l=0, r=0, b=0, t=30),
    ),
    frames=frames,
)
# Right-hand column: render the figure at the column's full width.
with col2:
    # The title's emoji had been stored as mojibake (mis-decoded UTF-8 of
    # U+1F4CA); restored to the real character.
    st.title("📊 3D Plot")
    st.plotly_chart(fig, use_container_width=True)