Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.decomposition import PCA | |
| import torch | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import math | |
# Load transformer model + tokenizer (cached across Streamlit reruns).
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load and return ``(tokenizer, model)`` for the Hugging Face id *model_name*.

    Streamlit re-executes this whole script on every user interaction; without
    ``st.cache_resource`` the tokenizer and model weights would be reloaded on
    each rerun. Per the Streamlit docs, cached resources are shared across
    sessions, so callers should treat the returned objects as read-only.

    Args:
        model_name: Hub id passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to ``MODEL_NAME``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model()
# Encode using mean pooling
def encode_texts(texts):
    """Embed each string in *texts* via attention-masked mean pooling.

    Uses the module-level ``tokenizer`` and ``model``. Token vectors from
    ``last_hidden_state`` are summed over the sequence axis (dim 1), counting
    only positions whose attention mask is 1, then divided by the number of
    such positions, so padding does not dilute shorter inputs.

    Args:
        texts: Strings to embed (a list in this app).

    Returns:
        A NumPy array with one row per input text. An empty list or tuple
        yields an empty ``(0, hidden_size)`` float32 array instead of being
        passed to the tokenizer, which cannot build a valid batch from it.
    """
    if isinstance(texts, (list, tuple)) and not texts:
        # NOTE(review): assumes the model config exposes ``hidden_size``
        # (true for BERT-style configs such as all-MiniLM-L6-v2) — confirm if
        # MODEL_NAME changes.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        hidden = model(**inputs).last_hidden_state
        # Broadcast the mask over the hidden dimension so padded tokens are zeroed.
        mask = inputs["attention_mask"].unsqueeze(-1).expand(hidden.shape).float()
        summed = torch.sum(hidden * mask, dim=1)
        # Guard the divisor so an all-padding row cannot divide by zero
        # (same clamp as the reference mean_pooling on the model card).
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled = summed / counts
    return pooled.cpu().numpy()
# Session state init: seed the word list only on the very first run, so any
# text the user has since submitted is left untouched on later reruns.
if "submitted_text" not in st.session_state:
    st.session_state["submitted_text"] = "\n".join(
        [
            "BMW",
            "Porsche",
            "Mercedes",
            "Coffee",
            "Tea",
            "Water",
            "Germany",
            "Italy",
            "Brazil",
            "Violin",
            "Drums",
            "Trumpet",
            "Man",
            "Women",
            "Child",
        ]
    )
# UI layout: a narrow input column (1 part) beside a wide plot column (3 parts).
col1, col2 = st.columns([1, 3])

# NOTE(review): the emoji in these labels look mis-encoded in this copy of
# the file; strings are kept exactly as found — restore from the original.
col1.title("π§ Embedding Input")
with col1.form(key="embedding_input_form"):
    st.form_submit_button("β Submit Text")
    # Bound to st.session_state["submitted_text"] through its key.
    st.text_area(
        label="Enter words (one per line)",
        key="submitted_text",
        height=400,
    )

# One entry per non-blank line, with surrounding whitespace trimmed.
raw_text = st.session_state.submitted_text
trimmed_lines = (line.strip() for line in raw_text.split("\n"))
texts = [word for word in trimmed_lines if word]

# PCA below asks for 3 components, which needs at least 3 samples.
if len(texts) <= 2:
    st.warning("Please enter at least three words.")
    st.stop()

embeddings = encode_texts(texts)
pca = PCA(n_components=3)
coords = pca.fit_transform(embeddings)
# Rotation frames: one camera position every 2 degrees, moving the eye around
# a circle of radius 2 in the x/y plane at a fixed height z=0.7.
frames = [
    go.Frame(
        layout=dict(
            scene_camera=dict(
                eye=dict(
                    x=2 * math.cos(math.radians(deg)),
                    y=2 * math.sin(math.radians(deg)),
                    z=0.7,
                )
            )
        )
    )
    for deg in range(0, 360, 2)
]
# Plotly figure with animation controls, assembled from named pieces.

# One labelled marker per input word, positioned by its 3 PCA coordinates.
scatter_trace = go.Scatter3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    mode="markers+text",
    text=texts,
    textposition="top center",
    textfont={"color": "black"},
    marker={"size": 6},
)

# Each axis plane gets a semi-transparent tint: X red, Y green, Z blue.
scene_axes = {
    "xaxis": {"title": "X", "showbackground": True, "backgroundcolor": "rgba(255,0,0,0.4)"},
    "yaxis": {"title": "Y", "showbackground": True, "backgroundcolor": "rgba(0,255,0,0.4)"},
    "zaxis": {"title": "Z", "showbackground": True, "backgroundcolor": "rgba(0,0,255,0.4)"},
}

# Button that plays the camera frames: frames=None, 50 ms per frame,
# no transition easing, resuming from the current frame.
rotate_button = {
    "label": "π Rotate",
    "method": "animate",
    "args": [
        None,
        {
            "frame": {"duration": 50, "redraw": True},
            "transition": {"duration": 0},
            "fromcurrent": True,
            "mode": "immediate",
        },
    ],
}

fig_layout = go.Layout(
    title="3D Embedding Projection",
    scene=scene_axes,
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "buttons": [rotate_button],
            "x": 0.05,
            "y": 0.9,
        }
    ],
    margin={"l": 0, "r": 0, "b": 0, "t": 30},
)

fig = go.Figure(data=[scatter_trace], layout=fig_layout, frames=frames)
# Right-hand column: heading plus the rotating 3D scatter, stretched to the
# column's width. NOTE(review): heading emoji appears mis-encoded; kept as-is.
col2.title("π 3D Plot")
col2.plotly_chart(fig, use_container_width=True)