Spaces:
Runtime error
Runtime error
stuff
Browse files
app.py
CHANGED
|
@@ -1,161 +1,4 @@
|
|
| 1 |
# app.py
|
| 2 |
-
# Full Streamlit app:
|
| 3 |
-
# - Uses Hugging Face "transformers" (AutoTokenizer/AutoModel) with mean pooling
|
| 4 |
-
# - Generates 3 random mixed datasets at startup (21 items each; 7 from 3 random base sets)
|
| 5 |
-
# - Mixed set names are lowercase and look like "sports/countries/fruits"
|
| 6 |
-
# - Model selector + 2D/3D toggle + Info button on the same row (compact)
|
| 7 |
-
# - Local info page via st.query_params["page"] ("demo" or "info")
|
| 8 |
-
# - 3D view has opaque pastelimport streamlit as st
|
| 9 |
-
import plotly.graph_objects as go
|
| 10 |
-
import numpy as np
|
| 11 |
-
import random
|
| 12 |
-
from transformers import AutoTokenizer, AutoModel
|
| 13 |
-
import torch
|
| 14 |
-
from sklearn.decomposition import PCA
|
| 15 |
-
|
| 16 |
-
# ----------------------------
|
| 17 |
-
# CONFIG
|
| 18 |
-
# ----------------------------
|
| 19 |
-
st.set_page_config(page_title="Embedding Demo", layout="wide")
|
| 20 |
-
|
| 21 |
-
# embedding models
|
| 22 |
-
EMBED_MODELS = {
|
| 23 |
-
"all-minilm-l6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
| 24 |
-
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
|
| 25 |
-
"multi-qa-mpnet-base-dot-v1": "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
| 26 |
-
}
|
| 27 |
-
|
| 28 |
-
# base datasets
|
| 29 |
-
DATASETS = {
|
| 30 |
-
"countries": ["germany", "france", "italy", "spain", "portugal", "poland", "netherlands", "belgium", "austria", "switzerland"],
|
| 31 |
-
"animals": ["cat", "dog", "lion", "tiger", "bear", "wolf", "fox", "eagle", "shark", "whale"],
|
| 32 |
-
"furniture": ["armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser","nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner","ottoman","console table","vanity","buffet","sectional sofa"],
|
| 33 |
-
"actors": ["brad pitt","angelina jolie","meryl streep","leonardo dicaprio","tom hanks","scarlett johansson","robert de niro","natalie portman","matt damon","cate blanchett"],
|
| 34 |
-
"rock group": ["beatles","rolling stones","pink floyd","queen","led zeppelin","u2","ac/dc","nirvana","radiohead","metallica"]
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
# ----------------------------
|
| 38 |
-
# RANDOM MIXED SETS
|
| 39 |
-
# ----------------------------
|
| 40 |
-
def make_mixed_sets():
|
| 41 |
-
keys = list(DATASETS.keys())
|
| 42 |
-
mixed_sets = {}
|
| 43 |
-
for i in range(3):
|
| 44 |
-
src = random.sample(keys, 3)
|
| 45 |
-
words = []
|
| 46 |
-
for s in src:
|
| 47 |
-
words += random.sample(DATASETS[s], min(7, len(DATASETS[s])))
|
| 48 |
-
name = "/".join(src)
|
| 49 |
-
mixed_sets[name] = words
|
| 50 |
-
return mixed_sets
|
| 51 |
-
|
| 52 |
-
DATASETS.update(make_mixed_sets())
|
| 53 |
-
|
| 54 |
-
# ----------------------------
|
| 55 |
-
# EMBEDDING CACHING
|
| 56 |
-
# ----------------------------
|
| 57 |
-
@st.cache_resource
|
| 58 |
-
def load_model(model_name):
|
| 59 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 60 |
-
model = AutoModel.from_pretrained(model_name)
|
| 61 |
-
return tokenizer, model
|
| 62 |
-
|
| 63 |
-
def embed_words(tokenizer, model, words):
|
| 64 |
-
with torch.no_grad():
|
| 65 |
-
encoded = tokenizer(words, padding=True, truncation=True, return_tensors="pt")
|
| 66 |
-
model_output = model(**encoded)
|
| 67 |
-
embeddings = model_output.last_hidden_state.mean(dim=1)
|
| 68 |
-
return embeddings.numpy()
|
| 69 |
-
|
| 70 |
-
# ----------------------------
|
| 71 |
-
# INFO PAGE
|
| 72 |
-
# ----------------------------
|
| 73 |
-
def info_page():
|
| 74 |
-
st.markdown("""
|
| 75 |
-
## about this demo
|
| 76 |
-
embeddings are a way to turn words or sentences into lists of numbers.
|
| 77 |
-
these numbers are arranged so that similar meanings are placed close together in this space.
|
| 78 |
-
this makes it possible to compare meanings with math.
|
| 79 |
-
here you can see how different words are mapped into a 2d or 3d space,
|
| 80 |
-
so related words appear near each other.
|
| 81 |
-
""")
|
| 82 |
-
st.page_link("?", label="⬅ back to demo")
|
| 83 |
-
|
| 84 |
-
# ----------------------------
|
| 85 |
-
# MAIN DEMO
|
| 86 |
-
# ----------------------------
|
| 87 |
-
def main_page():
|
| 88 |
-
# top controls
|
| 89 |
-
col1, col2, col3 = st.columns([1.5, 1, 1])
|
| 90 |
-
with col1:
|
| 91 |
-
dataset_name = st.selectbox("dataset", list(DATASETS.keys()))
|
| 92 |
-
with col2:
|
| 93 |
-
embed_choice = st.selectbox("embedding model", list(EMBED_MODELS.keys()))
|
| 94 |
-
with col3:
|
| 95 |
-
proj_mode = st.radio("projection", ["2d", "3d"], horizontal=True)
|
| 96 |
-
|
| 97 |
-
# show dataset words
|
| 98 |
-
words = DATASETS[dataset_name]
|
| 99 |
-
st.text_area("dataset words", "\n".join(words), height=120)
|
| 100 |
-
|
| 101 |
-
# embedding
|
| 102 |
-
tokenizer, model = load_model(EMBED_MODELS[embed_choice])
|
| 103 |
-
vecs = embed_words(tokenizer, model, words)
|
| 104 |
-
|
| 105 |
-
# project
|
| 106 |
-
dims = 2 if proj_mode == "2d" else 3
|
| 107 |
-
if vecs.shape[1] > dims:
|
| 108 |
-
vecs = PCA(n_components=dims).fit_transform(vecs)
|
| 109 |
-
|
| 110 |
-
# plot
|
| 111 |
-
fig = go.Figure()
|
| 112 |
-
if dims == 3:
|
| 113 |
-
fig.add_trace(go.Scatter3d(
|
| 114 |
-
x=vecs[:,0], y=vecs[:,1], z=vecs[:,2],
|
| 115 |
-
mode='markers+text',
|
| 116 |
-
text=words,
|
| 117 |
-
textposition="top center"
|
| 118 |
-
))
|
| 119 |
-
# cube faces
|
| 120 |
-
cube_faces = [
|
| 121 |
-
dict(type="mesh3d",
|
| 122 |
-
x=[-1,1,1,-1], y=[-1,-1,1,1], z=[1,1,1,1],
|
| 123 |
-
color='rgba(255,182,193,1)'),
|
| 124 |
-
dict(type="mesh3d",
|
| 125 |
-
x=[-1,1,1,-1], y=[-1,-1,1,1], z=[-1,-1,-1,-1],
|
| 126 |
-
color='rgba(173,216,230,1)')
|
| 127 |
-
]
|
| 128 |
-
for face in cube_faces:
|
| 129 |
-
fig.add_trace(go.Mesh3d(**face))
|
| 130 |
-
fig.update_layout(scene_camera=dict(
|
| 131 |
-
eye=dict(x=1.5, y=1.5, z=1.5)
|
| 132 |
-
))
|
| 133 |
-
if st.button("rotate"):
|
| 134 |
-
fig.update_layout(scene_camera=dict(
|
| 135 |
-
eye=dict(x=1.5, y=1.5, z=1.5)
|
| 136 |
-
), transition=dict(duration=2000), uirevision=True)
|
| 137 |
-
else:
|
| 138 |
-
fig.add_trace(go.Scatter(
|
| 139 |
-
x=vecs[:,0], y=vecs[:,1],
|
| 140 |
-
mode='markers+text',
|
| 141 |
-
text=words,
|
| 142 |
-
textposition="top center"
|
| 143 |
-
))
|
| 144 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 145 |
-
|
| 146 |
-
# ----------------------------
|
| 147 |
-
# ROUTER
|
| 148 |
-
# ----------------------------
|
| 149 |
-
params = st.query_params
|
| 150 |
-
if params.get("page") == "info":
|
| 151 |
-
info_page()
|
| 152 |
-
else:
|
| 153 |
-
st.page_link("?page=info", label="ℹ info")
|
| 154 |
-
main_page()
|
| 155 |
-
cube faces
|
| 156 |
-
# - Smooth rotation that STARTS from the CURRENT manual view (captured via relayout)
|
| 157 |
-
# - Continuous rotation until toggled off
|
| 158 |
-
|
| 159 |
import time
|
| 160 |
import random
|
| 161 |
import numpy as np
|
|
@@ -165,65 +8,70 @@ from sklearn.decomposition import PCA
|
|
| 165 |
import torch
|
| 166 |
from transformers import AutoTokenizer, AutoModel
|
| 167 |
|
| 168 |
-
# Optional but recommended: capture Plotly relayout (camera) to start rotation from manual view
|
| 169 |
-
# pip install streamlit-plotly-events
|
| 170 |
-
try:
|
| 171 |
-
from streamlit_plotly_events import plotly_events
|
| 172 |
-
HAVE_EVENTS = True
|
| 173 |
-
except Exception:
|
| 174 |
-
HAVE_EVENTS = False
|
| 175 |
-
|
| 176 |
st.set_page_config(page_title="Embedding Demo", layout="wide")
|
| 177 |
|
| 178 |
-
# -----------------------
|
| 179 |
-
# BASE DATASETS
|
| 180 |
-
# -----------------------
|
| 181 |
DATASETS = {
|
| 182 |
-
"
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
}
|
| 197 |
|
| 198 |
-
# -----------------------
|
| 199 |
-
#
|
| 200 |
-
#
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
def create_random_mixed_sets(base_dict, n_sets=3):
|
| 204 |
-
base_keys = list(base_dict.keys())
|
| 205 |
mixed = {}
|
| 206 |
for _ in range(n_sets):
|
| 207 |
-
sources = random.sample(
|
| 208 |
items = []
|
| 209 |
for s in sources:
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
mixed[name] = items[:21]
|
| 215 |
return mixed
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
st.session_state.mixed_inserted = True
|
| 221 |
|
| 222 |
-
# -----------------------
|
| 223 |
# MODELS (transformers)
|
| 224 |
-
# -----------------------
|
| 225 |
EMBED_MODELS = {
|
| 226 |
-
"all-
|
| 227 |
"all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
|
| 228 |
"all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
|
| 229 |
}
|
|
@@ -242,222 +90,180 @@ def embed_texts(model_name: str, texts_tuple: tuple):
|
|
| 242 |
with torch.no_grad():
|
| 243 |
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
| 244 |
outputs = model(**inputs)
|
| 245 |
-
# Mean pooling with attention mask
|
| 246 |
token_embeddings = outputs.last_hidden_state # (B,T,H)
|
| 247 |
mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
|
| 248 |
summed = (token_embeddings * mask).sum(dim=1)
|
| 249 |
counts = mask.sum(dim=1).clamp(min=1e-9)
|
| 250 |
-
embeddings = summed / counts
|
| 251 |
return embeddings.cpu().numpy()
|
| 252 |
|
| 253 |
-
# -----------------------
|
| 254 |
-
# STATE
|
| 255 |
-
# -----------------------
|
| 256 |
-
if "
|
| 257 |
-
st.session_state.
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
st.session_state.
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
if "proj_mode" not in st.session_state:
|
| 269 |
-
st.session_state.proj_mode = "3D"
|
| 270 |
-
|
| 271 |
-
# -----------------------
|
| 272 |
-
# ROUTING: info or demo
|
| 273 |
-
# -----------------------
|
| 274 |
-
params = st.query_params
|
| 275 |
-
page = params.get("page", ["demo"])[0]
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
| 279 |
st.rerun()
|
| 280 |
|
| 281 |
-
|
|
|
|
|
|
|
| 282 |
# INFO PAGE
|
| 283 |
-
# -----------------------
|
| 284 |
-
def
|
| 285 |
-
st.title("ℹ
|
| 286 |
-
st.write(
|
| 287 |
-
|
| 288 |
-
**
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
- semantic search
|
| 292 |
-
- clustering
|
| 293 |
-
-
|
| 294 |
- measuring similarity and analogies
|
| 295 |
|
| 296 |
-
|
| 297 |
and shows how related words cluster in the projected space.
|
| 298 |
-
|
| 299 |
-
)
|
| 300 |
-
|
| 301 |
-
set_page("demo")
|
| 302 |
|
| 303 |
-
# -----------------------
|
| 304 |
# DEMO PAGE
|
| 305 |
-
# -----------------------
|
| 306 |
-
def
|
| 307 |
-
#
|
| 308 |
-
|
| 309 |
-
with
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
index=1 if st.session_state.proj_mode == "3D" else 0)
|
| 323 |
-
|
| 324 |
-
# Two-column layout: dataset+editor (left), plot (right)
|
| 325 |
-
left, right = st.columns([1, 3], gap="large")
|
| 326 |
-
|
| 327 |
-
with left:
|
| 328 |
-
ds_keys = list(DATASETS.keys())
|
| 329 |
-
st.session_state.dataset_select = st.selectbox(
|
| 330 |
-
"Dataset", ds_keys, index=ds_keys.index(st.session_state.dataset_select)
|
| 331 |
-
)
|
| 332 |
-
default_text = "\n".join(DATASETS[st.session_state.dataset_select])
|
| 333 |
-
user_text = st.text_area("Words (one per line)", value=default_text, height=300)
|
| 334 |
-
words = [w.strip() for w in user_text.split("\n") if w.strip()]
|
| 335 |
-
if len(words) < 3:
|
| 336 |
-
st.info("Enter at least three lines to project.")
|
| 337 |
-
st.stop()
|
| 338 |
-
|
| 339 |
-
# Rotate toggle + speed
|
| 340 |
-
r1, r2 = st.columns([1, 1])
|
| 341 |
-
with r1:
|
| 342 |
-
if st.button("⏯ Rotate"):
|
| 343 |
-
st.session_state.rotate_on = not st.session_state.rotate_on
|
| 344 |
-
with r2:
|
| 345 |
-
speed = st.slider("Speed (deg/frame)", 1, 10, 3, help="Rotation speed in degrees per frame (3 is smooth).")
|
| 346 |
-
|
| 347 |
-
with right:
|
| 348 |
-
# Embed (cached) and PCA
|
| 349 |
-
embs = embed_texts(st.session_state.model_name, tuple(words))
|
| 350 |
-
n_comp = 3 if st.session_state.proj_mode == "3D" else 2
|
| 351 |
-
coords = PCA(n_components=n_comp).fit_transform(embs)
|
| 352 |
-
|
| 353 |
-
# Title centered & blue
|
| 354 |
-
title_html = f"<b style='color:#1f77b4; font-size:2.2rem;'>{st.session_state.dataset_select}</b>"
|
| 355 |
-
|
| 356 |
-
# Build figure
|
| 357 |
-
if st.session_state.proj_mode == "3D":
|
| 358 |
-
eye = st.session_state.live_camera.get("eye", {"x": 1.8, "y": 0.0, "z": 1.0})
|
| 359 |
-
fig = go.Figure(
|
| 360 |
-
data=[go.Scatter3d(
|
| 361 |
-
x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
|
| 362 |
-
mode="markers+text", text=words, textposition="top center",
|
| 363 |
-
marker=dict(size=6),
|
| 364 |
-
)],
|
| 365 |
-
layout=go.Layout(
|
| 366 |
-
title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
|
| 367 |
-
font=dict(size=32, color="#1f77b4")),
|
| 368 |
-
scene=dict(
|
| 369 |
-
camera=dict(eye=eye, projection=dict(type="perspective")),
|
| 370 |
-
xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
|
| 371 |
-
yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
|
| 372 |
-
zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
|
| 373 |
-
),
|
| 374 |
-
margin=dict(l=0, r=0, b=0, t=60),
|
| 375 |
-
uirevision="keep", # keep interactions across reruns
|
| 376 |
-
)
|
| 377 |
-
)
|
| 378 |
-
else:
|
| 379 |
-
fig = go.Figure(
|
| 380 |
-
data=[go.Scatter(
|
| 381 |
-
x=coords[:, 0], y=coords[:, 1],
|
| 382 |
-
mode="markers+text", text=words, textposition="top center",
|
| 383 |
-
marker=dict(size=10),
|
| 384 |
-
)],
|
| 385 |
-
layout=go.Layout(
|
| 386 |
-
title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
|
| 387 |
-
font=dict(size=32, color="#1f77b4")),
|
| 388 |
-
xaxis=dict(title="PC1"),
|
| 389 |
-
yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
|
| 390 |
-
margin=dict(l=0, r=0, b=0, t=60),
|
| 391 |
-
)
|
| 392 |
-
)
|
| 393 |
|
| 394 |
-
|
| 395 |
-
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
)
|
| 404 |
-
|
| 405 |
-
rel_key = "plotly_relayout-data-plot3d"
|
| 406 |
-
rel = st.session_state.get(rel_key)
|
| 407 |
-
if isinstance(rel, dict):
|
| 408 |
-
cam_update = rel.get("scene.camera") or {}
|
| 409 |
-
if "eye" in cam_update:
|
| 410 |
-
st.session_state.live_camera = {
|
| 411 |
-
"eye": cam_update["eye"],
|
| 412 |
-
"projection": {"type": "perspective"}
|
| 413 |
-
}
|
| 414 |
-
else:
|
| 415 |
-
# Sometimes relayout sends separate eye components
|
| 416 |
-
keys = ["scene.camera.eye.x", "scene.camera.eye.y", "scene.camera.eye.z"]
|
| 417 |
-
if all(k in rel for k in keys):
|
| 418 |
-
st.session_state.live_camera = {
|
| 419 |
-
"eye": {"x": rel[keys[0]], "y": rel[keys[1]], "z": rel[keys[2]]},
|
| 420 |
-
"projection": {"type": "perspective"}
|
| 421 |
-
}
|
| 422 |
-
else:
|
| 423 |
-
ph.plotly_chart(fig, use_container_width=True)
|
| 424 |
-
|
| 425 |
-
# Continuous rotation loop (only in 3D and when rotate_on is True)
|
| 426 |
-
if st.session_state.proj_mode == "3D" and st.session_state.rotate_on:
|
| 427 |
-
# Start from the CURRENT live camera (which includes any manual rotation)
|
| 428 |
-
eye = st.session_state.live_camera["eye"]
|
| 429 |
-
r = max(1e-6, float(np.sqrt(eye["x"]**2 + eye["y"]**2)))
|
| 430 |
-
z_eye = float(eye["z"])
|
| 431 |
-
# Current angle in radians based on x,y
|
| 432 |
-
angle = float(np.arctan2(eye["y"], eye["x"]))
|
| 433 |
-
|
| 434 |
-
# Render a batch of frames, then rerun to keep UI responsive
|
| 435 |
-
frames = 180 # ~6 seconds at 33ms per frame
|
| 436 |
-
step = np.deg2rad(speed)
|
| 437 |
-
|
| 438 |
-
for _ in range(frames):
|
| 439 |
-
# Update angle and camera eye
|
| 440 |
-
angle += step
|
| 441 |
-
new_eye = {"x": r * np.cos(angle), "y": r * np.sin(angle), "z": z_eye}
|
| 442 |
-
st.session_state.live_camera = {"eye": new_eye, "projection": {"type": "perspective"}}
|
| 443 |
-
fig.update_layout(scene_camera=st.session_state.live_camera)
|
| 444 |
-
|
| 445 |
-
# Render updated frame
|
| 446 |
-
ph.plotly_chart(fig, use_container_width=True)
|
| 447 |
-
time.sleep(0.033) # ~30 FPS
|
| 448 |
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
-
# If still
|
| 454 |
-
if st.session_state.
|
| 455 |
st.rerun()
|
| 456 |
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
# ROUTER
|
| 459 |
-
# -----------------------
|
| 460 |
if page == "info":
|
| 461 |
-
|
| 462 |
else:
|
| 463 |
-
|
|
|
|
| 1 |
# app.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import time
|
| 3 |
import random
|
| 4 |
import numpy as np
|
|
|
|
| 8 |
import torch
|
| 9 |
from transformers import AutoTokenizer, AutoModel
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
st.set_page_config(page_title="Embedding Demo", layout="wide")
|
| 12 |
|
| 13 |
+
# ----------------------------
|
| 14 |
+
# BASE DATASETS (lowercase)
|
| 15 |
+
# ----------------------------
|
| 16 |
DATASETS = {
|
| 17 |
+
"countries": [
|
| 18 |
+
"germany","france","italy","spain","portugal","poland","netherlands","belgium",
|
| 19 |
+
"austria","switzerland","greece","norway","sweden","finland","denmark","ireland",
|
| 20 |
+
"hungary","czechia","slovakia","slovenia","iceland","estonia","latvia","lithuania","romania"
|
| 21 |
+
],
|
| 22 |
+
"animals": [
|
| 23 |
+
"cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
|
| 24 |
+
"zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
|
| 25 |
+
"chimpanzee","gorilla","leopard","cheetah","lynx"
|
| 26 |
+
],
|
| 27 |
+
"furniture": [
|
| 28 |
+
"armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
|
| 29 |
+
"nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
|
| 30 |
+
"ottoman","console table","vanity","buffet","sectional sofa"
|
| 31 |
+
],
|
| 32 |
+
"actors": [
|
| 33 |
+
"brad pitt","angelina jolie","meryl streep","leonardo dicaprio","tom hanks","scarlett johansson","robert de niro",
|
| 34 |
+
"natalie portman","matt damon","cate blanchett","johnny depp","keanu reeves","hugh jackman","emma stone","ryan gosling",
|
| 35 |
+
"jennifer lawrence","christian bale","charlize theron","will smith","anne hathaway","denzel washington","morgan freeman",
|
| 36 |
+
"julia roberts","george clooney","kate winslet"
|
| 37 |
+
],
|
| 38 |
+
"rock group": [
|
| 39 |
+
"the beatles","rolling stones","pink floyd","queen","led zeppelin","u2","ac/dc","nirvana","radiohead","metallica",
|
| 40 |
+
"guns n' roses","red hot chili peppers","coldplay","pearl jam","the police","aerosmith","green day","foo fighters",
|
| 41 |
+
"the doors","bon jovi","deep purple","the who","the kinks","fleetwood mac","the beach boys"
|
| 42 |
+
],
|
| 43 |
+
"sports": [
|
| 44 |
+
"soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
|
| 45 |
+
"boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
|
| 46 |
+
"cricket","table tennis","gymnastics","hockey","climbing"
|
| 47 |
+
]
|
| 48 |
}
|
| 49 |
|
| 50 |
+
# ----------------------------
|
| 51 |
+
# RANDOM MIXED SETS (once per session)
|
| 52 |
+
# ----------------------------
|
| 53 |
+
def make_random_mixed_sets(base: dict, n_sets: int = 3) -> dict:
|
| 54 |
+
keys = list(base.keys())
|
|
|
|
|
|
|
| 55 |
mixed = {}
|
| 56 |
for _ in range(n_sets):
|
| 57 |
+
sources = random.sample(keys, 3)
|
| 58 |
items = []
|
| 59 |
for s in sources:
|
| 60 |
+
take = min(7, len(base[s]))
|
| 61 |
+
items.extend(random.sample(base[s], take))
|
| 62 |
+
mixed_name = "/".join(sources).lower()
|
| 63 |
+
mixed[mixed_name] = items[:21]
|
|
|
|
| 64 |
return mixed
|
| 65 |
|
| 66 |
+
if "mixed_added" not in st.session_state:
|
| 67 |
+
DATASETS.update(make_random_mixed_sets(DATASETS, 3))
|
| 68 |
+
st.session_state.mixed_added = True
|
|
|
|
| 69 |
|
| 70 |
+
# ----------------------------
|
| 71 |
# MODELS (transformers)
|
| 72 |
+
# ----------------------------
|
| 73 |
EMBED_MODELS = {
|
| 74 |
+
"all-minilm-l6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
|
| 75 |
"all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
|
| 76 |
"all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
|
| 77 |
}
|
|
|
|
| 90 |
with torch.no_grad():
|
| 91 |
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
| 92 |
outputs = model(**inputs)
|
|
|
|
| 93 |
token_embeddings = outputs.last_hidden_state # (B,T,H)
|
| 94 |
mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
|
| 95 |
summed = (token_embeddings * mask).sum(dim=1)
|
| 96 |
counts = mask.sum(dim=1).clamp(min=1e-9)
|
| 97 |
+
embeddings = summed / counts # mean pooling
|
| 98 |
return embeddings.cpu().numpy()
|
| 99 |
|
| 100 |
+
# ----------------------------
|
| 101 |
+
# STATE: camera + rotation
|
| 102 |
+
# ----------------------------
|
| 103 |
+
if "camera_eye" not in st.session_state:
|
| 104 |
+
st.session_state.camera_eye = {"x": 1.6, "y": 1.6, "z": 1.2}
|
| 105 |
+
if "spinning" not in st.session_state:
|
| 106 |
+
st.session_state.spinning = False
|
| 107 |
+
if "angle_rad" not in st.session_state:
|
| 108 |
+
# derive from eye.x, eye.y
|
| 109 |
+
e = st.session_state.camera_eye
|
| 110 |
+
st.session_state.angle_rad = float(np.arctan2(e["y"], e["x"]))
|
| 111 |
+
|
| 112 |
+
def update_eye_from_angle(angle_rad: float, radius: float, z: float):
|
| 113 |
+
return {"x": radius * np.cos(angle_rad), "y": radius * np.sin(angle_rad), "z": z}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
# ----------------------------
|
| 116 |
+
# NAVIGATION via st.query_params
|
| 117 |
+
# ----------------------------
|
| 118 |
+
def goto(page: str):
|
| 119 |
+
st.query_params["page"] = page
|
| 120 |
st.rerun()
|
| 121 |
|
| 122 |
+
page = st.query_params.get("page", ["demo"])[0]
|
| 123 |
+
|
| 124 |
+
# ----------------------------
|
| 125 |
# INFO PAGE
|
| 126 |
+
# ----------------------------
|
| 127 |
+
def info_page():
|
| 128 |
+
st.title("ℹ about this demo")
|
| 129 |
+
st.write("""
|
| 130 |
+
**embeddings** turn words (or longer text) into numerical vectors.
|
| 131 |
+
in this vector space, **semantically related** items end up **near** each other.
|
| 132 |
+
|
| 133 |
+
why this is useful:
|
| 134 |
+
- semantic search and retrieval
|
| 135 |
+
- clustering and topic discovery
|
| 136 |
+
- recommendations and deduplication
|
| 137 |
- measuring similarity and analogies
|
| 138 |
|
| 139 |
+
this demo embeds single words with a selectable model, reduces to 2d/3d with pca,
|
| 140 |
and shows how related words cluster in the projected space.
|
| 141 |
+
""".strip())
|
| 142 |
+
if st.button("⬅ back to demo"):
|
| 143 |
+
goto("demo")
|
|
|
|
| 144 |
|
| 145 |
+
# ----------------------------
|
| 146 |
# DEMO PAGE
|
| 147 |
+
# ----------------------------
|
| 148 |
+
def demo_page():
|
| 149 |
+
# top row: dataset, model + 2d/3d, info button
|
| 150 |
+
c1, c2, c3 = st.columns([2, 2, 1])
|
| 151 |
+
with c1:
|
| 152 |
+
ds_names = list(DATASETS.keys())
|
| 153 |
+
dataset_name = st.selectbox("dataset", ds_names, index=ds_names.index("furniture") if "furniture" in ds_names else 0)
|
| 154 |
+
with c2:
|
| 155 |
+
cc1, cc2 = st.columns([2, 1])
|
| 156 |
+
with cc1:
|
| 157 |
+
model_label = st.selectbox("embedding model", list(EMBED_MODELS.keys()))
|
| 158 |
+
model_name = EMBED_MODELS[model_label]
|
| 159 |
+
with cc2:
|
| 160 |
+
proj_mode = st.radio("projection", ["2d", "3d"], horizontal=True)
|
| 161 |
+
with c3:
|
| 162 |
+
if st.button("ℹ info"):
|
| 163 |
+
goto("info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
words = DATASETS[dataset_name]
|
| 166 |
+
st.text_area("dataset words", "\n".join(words), height=160)
|
| 167 |
|
| 168 |
+
# Embed + PCA
|
| 169 |
+
embs = embed_texts(model_name, tuple(words))
|
| 170 |
+
if proj_mode == "2d":
|
| 171 |
+
coords = PCA(n_components=2).fit_transform(embs)
|
| 172 |
+
else:
|
| 173 |
+
coords = PCA(n_components=3).fit_transform(embs)
|
| 174 |
+
|
| 175 |
+
title_html = f"<b style='color:#1f77b4; font-size:2.0rem;'>{dataset_name}</b>"
|
| 176 |
+
|
| 177 |
+
if proj_mode == "3d":
|
| 178 |
+
# compute radius from current eye, keep same z
|
| 179 |
+
eye = st.session_state.camera_eye
|
| 180 |
+
radius = float(np.sqrt(eye["x"]**2 + eye["y"]**2)) or 1.6
|
| 181 |
+
z_eye = float(eye["z"])
|
| 182 |
+
|
| 183 |
+
fig = go.Figure(
|
| 184 |
+
data=[go.Scatter3d(
|
| 185 |
+
x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
|
| 186 |
+
mode="markers+text", text=words, textposition="top center",
|
| 187 |
+
marker=dict(size=6),
|
| 188 |
+
)],
|
| 189 |
+
layout=go.Layout(
|
| 190 |
+
title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
|
| 191 |
+
font=dict(size=30, color="#1f77b4")),
|
| 192 |
+
scene=dict(
|
| 193 |
+
camera=dict(eye=eye, projection=dict(type="perspective")),
|
| 194 |
+
xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
|
| 195 |
+
yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
|
| 196 |
+
zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
|
| 197 |
+
),
|
| 198 |
+
margin=dict(l=0, r=0, b=0, t=60),
|
| 199 |
+
uirevision="keep",
|
| 200 |
)
|
| 201 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
# Controls under the plot: Start/Stop rotation
|
| 204 |
+
b1, b2 = st.columns([1, 1])
|
| 205 |
+
with b1:
|
| 206 |
+
start_clicked = st.button("▶ start rotation", disabled=st.session_state.spinning)
|
| 207 |
+
with b2:
|
| 208 |
+
stop_clicked = st.button("⏹ stop rotation", disabled=not st.session_state.spinning)
|
| 209 |
+
|
| 210 |
+
# If start pressed: turn on spinner and initialize angle from current stored eye
|
| 211 |
+
if start_clicked:
|
| 212 |
+
st.session_state.spinning = True
|
| 213 |
+
# start from stored angle (not capturing manual camera — simple approach)
|
| 214 |
+
st.session_state.angle_rad = float(np.arctan2(eye["y"], eye["x"]))
|
| 215 |
+
# fall through to loop below (this turn will render once, then continue)
|
| 216 |
+
|
| 217 |
+
# If stop pressed: turn off spinner (and keep stop disabled after)
|
| 218 |
+
if stop_clicked:
|
| 219 |
+
st.session_state.spinning = False
|
| 220 |
+
|
| 221 |
+
# Live render placeholder
|
| 222 |
+
placeholder = st.empty()
|
| 223 |
+
|
| 224 |
+
# First draw (static) before any loop
|
| 225 |
+
placeholder.plotly_chart(fig, use_container_width=True)
|
| 226 |
+
|
| 227 |
+
# Continuous rotation loop while spinning
|
| 228 |
+
if st.session_state.spinning:
|
| 229 |
+
# one "batch" of frames, then rerun to keep UI responsive
|
| 230 |
+
steps_per_batch = 120
|
| 231 |
+
step = np.deg2rad(3) # 3 degrees per frame ~ smooth
|
| 232 |
+
for _ in range(steps_per_batch):
|
| 233 |
+
if not st.session_state.spinning:
|
| 234 |
break
|
| 235 |
+
st.session_state.angle_rad += step
|
| 236 |
+
new_eye = update_eye_from_angle(st.session_state.angle_rad, radius, z_eye)
|
| 237 |
+
st.session_state.camera_eye = new_eye
|
| 238 |
+
fig.update_layout(scene_camera=dict(eye=new_eye, projection=dict(type="perspective")))
|
| 239 |
+
placeholder.plotly_chart(fig, use_container_width=True)
|
| 240 |
+
time.sleep(0.033) # ~30 FPS
|
| 241 |
|
| 242 |
+
# If still spinning after this batch, rerun to keep going
|
| 243 |
+
if st.session_state.spinning:
|
| 244 |
st.rerun()
|
| 245 |
|
| 246 |
+
else:
|
| 247 |
+
fig = go.Figure(
|
| 248 |
+
data=[go.Scatter(
|
| 249 |
+
x=coords[:, 0], y=coords[:, 1],
|
| 250 |
+
mode="markers+text", text=words, textposition="top center",
|
| 251 |
+
marker=dict(size=9),
|
| 252 |
+
)],
|
| 253 |
+
layout=go.Layout(
|
| 254 |
+
title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
|
| 255 |
+
font=dict(size=30, color="#1f77b4")),
|
| 256 |
+
xaxis=dict(title="PC1"),
|
| 257 |
+
yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
|
| 258 |
+
margin=dict(l=0, r=0, b=0, t=60),
|
| 259 |
+
)
|
| 260 |
+
)
|
| 261 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 262 |
+
|
| 263 |
+
# ----------------------------
|
| 264 |
# ROUTER
|
| 265 |
+
# ----------------------------
|
| 266 |
if page == "info":
|
| 267 |
+
info_page()
|
| 268 |
else:
|
| 269 |
+
demo_page()
|