Spaces:
Runtime error
Runtime error
File size: 3,583 Bytes
08f0b45 5d075b2 6ec32ce 423ac96 5d075b2 423ac96 6ec32ce 423ac96 6ec32ce 5d075b2 423ac96 6449dfd 423ac96 6ec32ce 423ac96 6ec32ce 423ac96 6ec32ce 423ac96 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | import streamlit as st
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
import torch
import numpy as np
import plotly.graph_objects as go
import math
# Load the MiniLM tokenizer and encoder. st.cache_resource caches the returned
# objects, so Streamlit reruns reuse them instead of loading them again.
@st.cache_resource
def load_model():
    """Return ``(tokenizer, model)`` for sentence-transformers/all-MiniLM-L6-v2."""
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModel.from_pretrained(checkpoint),
    )


tokenizer, model = load_model()
# Encode using mean pooling
def encode_texts(texts):
    """Embed each string by mean-pooling the model's last hidden states.

    Uses the module-level ``tokenizer`` and ``model`` returned by
    ``load_model()``. Padding positions are excluded from the average by
    weighting every token state with the tokenizer's attention mask.

    Args:
        texts: list of strings to embed.

    Returns:
        A NumPy array with one row per input string and
        ``model.config.hidden_size`` columns. An empty ``texts`` yields an
        array with zero rows.
    """
    if not texts:
        # Nothing to embed: skip the tokenizer/model round-trip entirely.
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        hidden = model(**inputs).last_hidden_state
        # attention_mask is (batch, seq); adding a trailing axis lets it
        # broadcast across the hidden dimension of the token states.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (hidden * mask).sum(dim=1)
        # Clamp so a row with no attended tokens cannot divide by zero
        # (same guard as the reference sentence-transformers pooling).
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).cpu().numpy()
# Session state init: on the first run, seed the "submitted_text" key (the
# text area uses key="submitted_text") with a default word list, one per line.
_DEFAULT_WORDS = (
    "BMW", "Porsche", "Mercedes",   # car makers
    "Coffee", "Tea", "Water",       # drinks
    "Germany", "Italy", "Brazil",   # countries
    "Violin", "Drums", "Trumpet",   # instruments
    "Man", "Women", "Child",        # people
)
if "submitted_text" not in st.session_state:
    st.session_state.submitted_text = "\n".join(_DEFAULT_WORDS)
# UI layout: a narrow input column on the left and a wide plot column on the right.
col1, col2 = st.columns([1, 3])
with col1:
    st.title("🧠 Embedding Input")
    # Widgets inside a form only push their value into session state when the
    # submit button is pressed, so typing in the text area does not rerun the app.
    with st.form(key="embedding_input_form"):
        st.form_submit_button("✅ Submit Text")
        st.text_area(
            label="Enter words (one per line)",
            key="submitted_text",
            height=400,
        )

# One word per non-blank line, with surrounding whitespace trimmed.
texts = list(filter(None, map(str.strip, st.session_state.submitted_text.split("\n"))))

# PCA(n_components=3) below needs at least three samples; halt this run otherwise.
if len(texts) < 3:
    st.warning("Please enter at least three words.")
    st.stop()

embeddings = encode_texts(texts)
pca = PCA(n_components=3)
coords = pca.fit_transform(embeddings)
# Rotation frames: one scene-camera position every 2 degrees around the z-axis.
# The range stops at 358 degrees, so the sequence never repeats the 0-degree view.
def _orbit_camera(degrees):
    """Camera at xy-radius 2 and height 0.7, rotated *degrees* about the z-axis."""
    theta = math.radians(degrees)
    return dict(eye=dict(x=2 * math.cos(theta), y=2 * math.sin(theta), z=0.7))


frames = [
    go.Frame(layout=dict(scene_camera=_orbit_camera(step)))
    for step in range(0, 360, 2)
]
# Plotly figure: a labelled 3-D scatter of the PCA coordinates, plus a
# "Rotate" button that plays the camera frames in `frames`.
word_points = go.Scatter3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    mode="markers+text",
    text=texts,
    textposition="top center",
    textfont=dict(color="black"),
    marker=dict(size=6),
)


def _shaded_axis(axis_title, rgba):
    """Axis settings with a semi-transparent tinted background wall."""
    return dict(title=axis_title, showbackground=True, backgroundcolor=rgba)


# Passing None as the frame list animates all frames, 50 ms each, starting
# from the current one; redraw=True asks plotly to fully redraw every frame.
rotate_button = dict(
    label="🔄 Rotate",
    method="animate",
    args=[
        None,
        dict(
            frame=dict(duration=50, redraw=True),
            transition=dict(duration=0),
            fromcurrent=True,
            mode="immediate",
        ),
    ],
)

figure_layout = go.Layout(
    title="3D Embedding Projection",
    scene=dict(
        xaxis=_shaded_axis("X", "rgba(255,0,0,0.4)"),
        yaxis=_shaded_axis("Y", "rgba(0,255,0,0.4)"),
        zaxis=_shaded_axis("Z", "rgba(0,0,255,0.4)"),
    ),
    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            buttons=[rotate_button],
            x=0.05,
            y=0.9,
        ),
    ],
    margin=dict(l=0, r=0, b=0, t=30),
)

fig = go.Figure(data=[word_points], layout=figure_layout, frames=frames)
# Plot column: write straight into the right-hand column container and let
# the figure stretch to that column's width.
col2.title("📊 3D Plot")
col2.plotly_chart(fig, use_container_width=True)