"""Streamlit app: embed a list of words and show them in a rotating 3D PCA plot."""

import math

import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# PCA(n_components=3) requires at least three samples, so fewer words cannot be projected.
MIN_WORDS = 3

# Initial contents of the text area (one word per line).
DEFAULT_WORDS = """BMW
Porsche
Mercedes
Coffee
Tea
Water
Germany
Italy
Brazil
Violin
Drums
Trumpet
Man
Women
Child"""


@st.cache_resource
def load_model():
    """Load the tokenizer and transformer once per Streamlit server process.

    Returns:
        A ``(tokenizer, model)`` tuple for ``MODEL_NAME``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    return tokenizer, model


tokenizer, model = load_model()


@st.cache_data
def encode_texts(texts):
    """Return one L2-normalised sentence embedding per string in *texts*.

    Token vectors from the model's last hidden state are averaged over the
    real (non-padding) tokens and then L2-normalised, as in the usage example
    on the all-MiniLM-L6-v2 model card. Results are cached by Streamlit, so
    reruns with the same word list skip the forward pass.

    Args:
        texts: List of strings to embed.

    Returns:
        A NumPy array with one row per input string.
    """
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        hidden = model(**inputs).last_hidden_state
        # Zero out padding positions so they do not contribute to the mean.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp avoids a division by zero should a row contain no real tokens.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)
    return pooled.cpu().numpy()


def parse_words(raw_text):
    """Return the non-empty, whitespace-stripped lines of *raw_text*."""
    return [line.strip() for line in raw_text.splitlines() if line.strip()]


def orbit_camera(angle_deg, radius=2.0, height=0.7):
    """Return a Plotly 3D camera whose eye sits on a horizontal circle.

    Args:
        angle_deg: Position on the circle, in degrees.
        radius: Distance of the eye from the z-axis.
        height: z-coordinate of the eye.
    """
    rad = math.radians(angle_deg)
    return dict(eye=dict(x=radius * math.cos(rad), y=radius * math.sin(rad), z=height))


def build_rotation_frames(step_deg=2, radius=2.0, height=0.7):
    """Return animation frames that move the camera once around the z-axis.

    Args:
        step_deg: Angular increment between consecutive frames, in degrees.
        radius: Passed to ``orbit_camera``.
        height: Passed to ``orbit_camera``.
    """
    return [
        go.Frame(layout=dict(scene_camera=orbit_camera(angle, radius, height)))
        for angle in range(0, 360, step_deg)
    ]


def build_figure(coords, labels, frames, camera):
    """Return a labelled 3D scatter of *coords* with a "Rotate" animation button.

    Args:
        coords: Array indexed as ``coords[:, 0..2]`` (one row per point).
        labels: One text label per row of *coords*.
        frames: Animation frames played by the "Rotate" button.
        camera: Initial scene camera; pass the first frame's camera so
            starting the animation does not cause a visible jump.
    """
    scatter = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers+text",
        text=labels,
        textposition="top center",
        textfont=dict(color="black"),
        marker=dict(size=6),
    )
    rotate_button = dict(
        label="🔄 Rotate",
        method="animate",
        args=[
            None,  # None: animate through all of the figure's frames
            dict(
                frame=dict(duration=50, redraw=True),
                transition=dict(duration=0),
                fromcurrent=True,
                mode="immediate",
            ),
        ],
    )
    layout = go.Layout(
        title="3D Embedding Projection",
        scene=dict(
            xaxis=dict(title="X", showbackground=True, backgroundcolor="rgba(255,0,0,0.4)"),
            yaxis=dict(title="Y", showbackground=True, backgroundcolor="rgba(0,255,0,0.4)"),
            zaxis=dict(title="Z", showbackground=True, backgroundcolor="rgba(0,0,255,0.4)"),
            camera=camera,
        ),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[rotate_button],
                x=0.05,
                y=0.9,
            )
        ],
        margin=dict(l=0, r=0, b=0, t=30),
    )
    return go.Figure(data=[scatter], layout=layout, frames=frames)


# Session state init: seed the text area with the example words on first run.
if "submitted_text" not in st.session_state:
    st.session_state.submitted_text = DEFAULT_WORDS

# UI layout: narrow input column on the left, wide plot column on the right.
col1, col2 = st.columns([1, 3])

with col1:
    st.title("🧠 Embedding Input")
    # Inside a form, edits to the text area only take effect on submit.
    with st.form(key="embedding_input_form"):
        st.form_submit_button("✅ Submit Text")
        st.text_area(
            label="Enter words (one per line)",
            key="submitted_text",
            height=400,
        )

    texts = parse_words(st.session_state.submitted_text)
    if len(texts) < MIN_WORDS:
        st.warning("Please enter at least three words.")
        st.stop()

embeddings = encode_texts(texts)
coords = PCA(n_components=3).fit_transform(embeddings)

frames = build_rotation_frames()
fig = build_figure(coords, texts, frames, camera=orbit_camera(0))

with col2:
    st.title("📊 3D Plot")
    st.plotly_chart(fig, use_container_width=True)