File size: 3,583 Bytes
08f0b45
5d075b2
6ec32ce
423ac96
 
 
 
5d075b2
423ac96
6ec32ce
423ac96
 
 
6ec32ce
5d075b2
423ac96
6449dfd
423ac96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec32ce
423ac96
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec32ce
 
423ac96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec32ce
 
423ac96
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
import torch
import numpy as np
import plotly.graph_objects as go
import math

# Load the transformer model and its tokenizer.
@st.cache_resource
def load_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Load a tokenizer and encoder model from one Hugging Face checkpoint.

    Both objects are created from the same ``model_name`` so they cannot
    drift apart when the checkpoint is changed. The function is wrapped in
    ``st.cache_resource``, so repeated calls with the same argument are
    expected to reuse the already-loaded objects instead of reloading them.

    Args:
        model_name: Checkpoint identifier passed to both ``from_pretrained``
            calls. Defaults to the MiniLM sentence-embedding checkpoint this
            app was written for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

tokenizer, model = load_model()

# Encode using mean pooling
def encode_texts(texts):
    """Return one embedding row per entry of *texts* as a NumPy array.

    The texts are tokenized as a single padded batch with the module-level
    ``tokenizer``, run through the module-level ``model`` without gradient
    tracking, and the last hidden state is averaged over the token axis
    using the attention mask as weights.
    """
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        # Trailing singleton axis lets the mask broadcast across the hidden
        # dimension when multiplied with the hidden states.
        weights = batch["attention_mask"].unsqueeze(-1).float()
        weighted_sum = (hidden * weights).sum(dim=1)
        token_counts = weights.sum(dim=1)
        sentence_vectors = weighted_sum / token_counts

    return sentence_vectors.cpu().numpy()

# Seed the session with sample words, but only when no text is stored yet.
if "submitted_text" not in st.session_state:
    default_words = (
        "BMW",
        "Porsche",
        "Mercedes",
        "Coffee",
        "Tea",
        "Water",
        "Germany",
        "Italy",
        "Brazil",
        "Violin",
        "Drums",
        "Trumpet",
        "Man",
        "Women",
        "Child",
    )
    # One word per line, no trailing newline.
    st.session_state["submitted_text"] = "\n".join(default_words)

# Two-column page: input on the left, plot on the right (1:3 width ratio).
col1, col2 = st.columns([1, 3])

col1.title("🧠 Embedding Input")

# The form lives in the left column. The submit button is created before the
# text area, so it is added to the form first.
input_form = col1.form(key="embedding_input_form")
input_form.form_submit_button("✅ Submit Text")
# The widget key matches the session-state entry the rest of the script reads.
input_form.text_area(
    label="Enter words (one per line)",
    key="submitted_text",
    height=400,
)

# Split the stored text on newlines, trim each entry and drop the empty ones.
raw_lines = st.session_state.submitted_text.split("\n")
texts = [entry for entry in map(str.strip, raw_lines) if entry]

# Stop the script run early when there are fewer than three entries
# (presumably because the 3-component PCA below needs at least three samples).
if len(texts) < 3:
    st.warning("Please enter at least three words.")
    st.stop()

# Embed the entries, then reduce the embeddings to three coordinates each.
embeddings = encode_texts(texts)
projector = PCA(n_components=3)
coords = projector.fit_transform(embeddings)

# Animation frames: one per 2-degree step from 0 up to (not including) 360.
# Each frame only repositions the camera eye on a circle of radius 2 in the
# x/y plane, at a fixed z of 0.7.
frames = [
    go.Frame(
        layout=dict(
            scene_camera=dict(
                eye=dict(x=2 * math.cos(theta), y=2 * math.sin(theta), z=0.7)
            )
        )
    )
    for theta in map(math.radians, range(0, 360, 2))
]

# Build the animated 3D scatter figure from named parts.


def _axis(title, background):
    """Return the settings for one scene axis with a visible coloured background."""
    return dict(title=title, showbackground=True, backgroundcolor=background)


# One labelled marker per entry, positioned by its three PCA coordinates.
points_trace = go.Scatter3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    mode="markers+text",
    text=texts,
    textposition="top center",
    textfont=dict(color="black"),
    marker=dict(size=6),
)

# Translucent red / green / blue backgrounds for the X / Y / Z axes.
scene_settings = dict(
    xaxis=_axis("X", "rgba(255,0,0,0.4)"),
    yaxis=_axis("Y", "rgba(0,255,0,0.4)"),
    zaxis=_axis("Z", "rgba(0,0,255,0.4)"),
)

# Playback options handed to the "animate" method: 50 ms per frame with a
# redraw, no transition time, starting from the current frame.
playback_options = dict(
    frame=dict(duration=50, redraw=True),
    transition=dict(duration=0),
    fromcurrent=True,
    mode="immediate",
)

# The first animate argument is None, i.e. no specific frame names are given.
rotate_button = dict(
    label="🔄 Rotate",
    method="animate",
    args=[None, playback_options],
)

# A single button group placed near the top-left of the figure.
button_menu = dict(
    type="buttons",
    showactive=False,
    buttons=[rotate_button],
    x=0.05,
    y=0.9,
)

figure_layout = go.Layout(
    title="3D Embedding Projection",
    scene=scene_settings,
    updatemenus=[button_menu],
    margin=dict(l=0, r=0, b=0, t=30),
)

fig = go.Figure(data=[points_trace], layout=figure_layout, frames=frames)

# Right-hand column: heading followed by the figure.
col2.title("📊 3D Plot")
col2.plotly_chart(fig, use_container_width=True)