berndf committed on
Commit
6449dfd
·
verified ·
1 Parent(s): f849a1a
Files changed (1) hide show
  1. app.py +201 -395
app.py CHANGED
@@ -1,161 +1,4 @@
1
  # app.py
2
- # Full Streamlit app:
3
- # - Uses Hugging Face "transformers" (AutoTokenizer/AutoModel) with mean pooling
4
- # - Generates 3 random mixed datasets at startup (21 items each; 7 from 3 random base sets)
5
- # - Mixed set names are lowercase and look like "sports/countries/fruits"
6
- # - Model selector + 2D/3D toggle + Info button on the same row (compact)
7
- # - Local info page via st.query_params["page"] ("demo" or "info")
8
- # - 3D view has opaque pastelimport streamlit as st
9
- import plotly.graph_objects as go
10
- import numpy as np
11
- import random
12
- from transformers import AutoTokenizer, AutoModel
13
- import torch
14
- from sklearn.decomposition import PCA
15
-
16
- # ----------------------------
17
- # CONFIG
18
- # ----------------------------
19
- st.set_page_config(page_title="Embedding Demo", layout="wide")
20
-
21
- # embedding models
22
- EMBED_MODELS = {
23
- "all-minilm-l6-v2": "sentence-transformers/all-MiniLM-L6-v2",
24
- "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
25
- "multi-qa-mpnet-base-dot-v1": "sentence-transformers/multi-qa-mpnet-base-dot-v1"
26
- }
27
-
28
- # base datasets
29
- DATASETS = {
30
- "countries": ["germany", "france", "italy", "spain", "portugal", "poland", "netherlands", "belgium", "austria", "switzerland"],
31
- "animals": ["cat", "dog", "lion", "tiger", "bear", "wolf", "fox", "eagle", "shark", "whale"],
32
- "furniture": ["armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser","nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner","ottoman","console table","vanity","buffet","sectional sofa"],
33
- "actors": ["brad pitt","angelina jolie","meryl streep","leonardo dicaprio","tom hanks","scarlett johansson","robert de niro","natalie portman","matt damon","cate blanchett"],
34
- "rock group": ["beatles","rolling stones","pink floyd","queen","led zeppelin","u2","ac/dc","nirvana","radiohead","metallica"]
35
- }
36
-
37
- # ----------------------------
38
- # RANDOM MIXED SETS
39
- # ----------------------------
40
- def make_mixed_sets():
41
- keys = list(DATASETS.keys())
42
- mixed_sets = {}
43
- for i in range(3):
44
- src = random.sample(keys, 3)
45
- words = []
46
- for s in src:
47
- words += random.sample(DATASETS[s], min(7, len(DATASETS[s])))
48
- name = "/".join(src)
49
- mixed_sets[name] = words
50
- return mixed_sets
51
-
52
- DATASETS.update(make_mixed_sets())
53
-
54
- # ----------------------------
55
- # EMBEDDING CACHING
56
- # ----------------------------
57
- @st.cache_resource
58
- def load_model(model_name):
59
- tokenizer = AutoTokenizer.from_pretrained(model_name)
60
- model = AutoModel.from_pretrained(model_name)
61
- return tokenizer, model
62
-
63
- def embed_words(tokenizer, model, words):
64
- with torch.no_grad():
65
- encoded = tokenizer(words, padding=True, truncation=True, return_tensors="pt")
66
- model_output = model(**encoded)
67
- embeddings = model_output.last_hidden_state.mean(dim=1)
68
- return embeddings.numpy()
69
-
70
- # ----------------------------
71
- # INFO PAGE
72
- # ----------------------------
73
- def info_page():
74
- st.markdown("""
75
- ## about this demo
76
- embeddings are a way to turn words or sentences into lists of numbers.
77
- these numbers are arranged so that similar meanings are placed close together in this space.
78
- this makes it possible to compare meanings with math.
79
- here you can see how different words are mapped into a 2d or 3d space,
80
- so related words appear near each other.
81
- """)
82
- st.page_link("?", label="⬅ back to demo")
83
-
84
- # ----------------------------
85
- # MAIN DEMO
86
- # ----------------------------
87
- def main_page():
88
- # top controls
89
- col1, col2, col3 = st.columns([1.5, 1, 1])
90
- with col1:
91
- dataset_name = st.selectbox("dataset", list(DATASETS.keys()))
92
- with col2:
93
- embed_choice = st.selectbox("embedding model", list(EMBED_MODELS.keys()))
94
- with col3:
95
- proj_mode = st.radio("projection", ["2d", "3d"], horizontal=True)
96
-
97
- # show dataset words
98
- words = DATASETS[dataset_name]
99
- st.text_area("dataset words", "\n".join(words), height=120)
100
-
101
- # embedding
102
- tokenizer, model = load_model(EMBED_MODELS[embed_choice])
103
- vecs = embed_words(tokenizer, model, words)
104
-
105
- # project
106
- dims = 2 if proj_mode == "2d" else 3
107
- if vecs.shape[1] > dims:
108
- vecs = PCA(n_components=dims).fit_transform(vecs)
109
-
110
- # plot
111
- fig = go.Figure()
112
- if dims == 3:
113
- fig.add_trace(go.Scatter3d(
114
- x=vecs[:,0], y=vecs[:,1], z=vecs[:,2],
115
- mode='markers+text',
116
- text=words,
117
- textposition="top center"
118
- ))
119
- # cube faces
120
- cube_faces = [
121
- dict(type="mesh3d",
122
- x=[-1,1,1,-1], y=[-1,-1,1,1], z=[1,1,1,1],
123
- color='rgba(255,182,193,1)'),
124
- dict(type="mesh3d",
125
- x=[-1,1,1,-1], y=[-1,-1,1,1], z=[-1,-1,-1,-1],
126
- color='rgba(173,216,230,1)')
127
- ]
128
- for face in cube_faces:
129
- fig.add_trace(go.Mesh3d(**face))
130
- fig.update_layout(scene_camera=dict(
131
- eye=dict(x=1.5, y=1.5, z=1.5)
132
- ))
133
- if st.button("rotate"):
134
- fig.update_layout(scene_camera=dict(
135
- eye=dict(x=1.5, y=1.5, z=1.5)
136
- ), transition=dict(duration=2000), uirevision=True)
137
- else:
138
- fig.add_trace(go.Scatter(
139
- x=vecs[:,0], y=vecs[:,1],
140
- mode='markers+text',
141
- text=words,
142
- textposition="top center"
143
- ))
144
- st.plotly_chart(fig, use_container_width=True)
145
-
146
- # ----------------------------
147
- # ROUTER
148
- # ----------------------------
149
- params = st.query_params
150
- if params.get("page") == "info":
151
- info_page()
152
- else:
153
- st.page_link("?page=info", label="ℹ info")
154
- main_page()
155
- cube faces
156
- # - Smooth rotation that STARTS from the CURRENT manual view (captured via relayout)
157
- # - Continuous rotation until toggled off
158
-
159
  import time
160
  import random
161
  import numpy as np
@@ -165,65 +8,70 @@ from sklearn.decomposition import PCA
165
  import torch
166
  from transformers import AutoTokenizer, AutoModel
167
 
168
- # Optional but recommended: capture Plotly relayout (camera) to start rotation from manual view
169
- # pip install streamlit-plotly-events
170
- try:
171
- from streamlit_plotly_events import plotly_events
172
- HAVE_EVENTS = True
173
- except Exception:
174
- HAVE_EVENTS = False
175
-
176
  st.set_page_config(page_title="Embedding Demo", layout="wide")
177
 
178
- # -----------------------
179
- # BASE DATASETS
180
- # -----------------------
181
  DATASETS = {
182
- "sports": ["Soccer","Basketball","Tennis","Baseball","Golf","Swimming","Cycling","Running","Volleyball","Rugby",
183
- "Boxing","Skiing","Surfing","Skateboarding","Hiking","Rowing","Fencing","Gymnastics","Badminton","Cricket","Wrestling"],
184
- "countries": ["Germany","France","Italy","Spain","Portugal","Netherlands","Belgium","Sweden","Norway","Denmark",
185
- "Finland","Poland","Austria","Switzerland","Greece","Turkey","Ireland","Hungary","Czechia","Slovakia","Slovenia"],
186
- "animals": ["Cat","Dog","Elephant","Tiger","Lion","Horse","Cow","Sheep","Goat","Monkey","Bear","Wolf","Deer","Kangaroo",
187
- "Panda","Rabbit","Fox","Giraffe","Zebra","Hippopotamus","Crocodile"],
188
- "furniture": ["Armchair","Sofa","Dining table","Coffee table","Bookshelf","Bed","Wardrobe","Desk","Office chair","Dresser",
189
- "Nightstand","Side table","TV stand","Loveseat","Chaise lounge","Bench","Hutch","Kitchen island","Futon","Recliner",
190
- "Ottoman","Console table","Vanity","Buffet","Sectional sofa"],
191
- "actor": ["Leonardo DiCaprio","Brad Pitt","Tom Hanks","Johnny Depp","Robert De Niro","Al Pacino","Morgan Freeman","Denzel Washington","Tom Cruise",
192
- "Will Smith","Matt Damon","Harrison Ford","George Clooney","Christian Bale","Keanu Reeves","Russell Crowe","Hugh Jackman",
193
- "Samuel L. Jackson","Anthony Hopkins","Mark Wahlberg","Edward Norton"],
194
- "rock group": ["Queen","The Beatles","Rolling Stones","Pink Floyd","Led Zeppelin","AC/DC","Nirvana","Metallica","U2","The Who",
195
- "Guns N' Roses","Coldplay","Red Hot Chili Peppers","The Doors","Radiohead","Pearl Jam","The Police","Aerosmith","Deep Purple","Oasis","Foo Fighters"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  }
197
 
198
- # -----------------------
199
- # CREATE 3 RANDOM MIXED SETS ON STARTUP
200
- # Each: pick 3 distinct source sets; 7 items from each (21 total)
201
- # Name: "set1/set2/set3" (lowercase)
202
- # -----------------------
203
- def create_random_mixed_sets(base_dict, n_sets=3):
204
- base_keys = list(base_dict.keys())
205
  mixed = {}
206
  for _ in range(n_sets):
207
- sources = random.sample(base_keys, 3)
208
  items = []
209
  for s in sources:
210
- # Guard: some sets might have fewer than 7 (we do min)
211
- k = min(7, len(base_dict[s]))
212
- items.extend(random.sample(base_dict[s], k))
213
- name = "/".join(sources).lower()
214
- mixed[name] = items[:21]
215
  return mixed
216
 
217
- # Only generate once per process
218
- if "mixed_inserted" not in st.session_state:
219
- DATASETS.update(create_random_mixed_sets(DATASETS, n_sets=3))
220
- st.session_state.mixed_inserted = True
221
 
222
- # -----------------------
223
  # MODELS (transformers)
224
- # -----------------------
225
  EMBED_MODELS = {
226
- "all-MiniLM-L6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
227
  "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
228
  "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
229
  }
@@ -242,222 +90,180 @@ def embed_texts(model_name: str, texts_tuple: tuple):
242
  with torch.no_grad():
243
  inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
244
  outputs = model(**inputs)
245
- # Mean pooling with attention mask
246
  token_embeddings = outputs.last_hidden_state # (B,T,H)
247
  mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
248
  summed = (token_embeddings * mask).sum(dim=1)
249
  counts = mask.sum(dim=1).clamp(min=1e-9)
250
- embeddings = summed / counts
251
  return embeddings.cpu().numpy()
252
 
253
- # -----------------------
254
- # STATE FOR CAMERA + ROTATION
255
- # -----------------------
256
- if "live_camera" not in st.session_state:
257
- st.session_state.live_camera = {"eye": {"x": 1.8, "y": 0.0, "z": 1.0}, "projection": {"type": "perspective"}}
258
-
259
- if "rotate_on" not in st.session_state:
260
- st.session_state.rotate_on = False
261
-
262
- if "dataset_select" not in st.session_state:
263
- st.session_state.dataset_select = "furniture"
264
-
265
- if "model_name" not in st.session_state:
266
- st.session_state.model_name = list(EMBED_MODELS.values())[0]
267
-
268
- if "proj_mode" not in st.session_state:
269
- st.session_state.proj_mode = "3D"
270
-
271
- # -----------------------
272
- # ROUTING: info or demo
273
- # -----------------------
274
- params = st.query_params
275
- page = params.get("page", ["demo"])[0]
276
 
277
- def set_page(name: str):
278
- st.query_params["page"] = name
 
 
 
279
  st.rerun()
280
 
281
- # -----------------------
 
 
282
  # INFO PAGE
283
- # -----------------------
284
- def show_info():
285
- st.title("ℹ About this demo")
286
- st.write(
287
- """
288
- **Embeddings** turn words (or longer text) into numerical vectors.
289
- In this vector space, **semantically related** items end up **near** each other.
290
- That property makes embeddings useful for:
291
- - semantic search & retrieval
292
- - clustering & topic discovery
293
- - recommendation & deduplication
294
  - measuring similarity and analogies
295
 
296
- This demo embeds single words with a selectable model, reduces to 2D/3D with PCA,
297
  and shows how related words cluster in the projected space.
298
- """
299
- )
300
- if st.button("⬅ Back to demo"):
301
- set_page("demo")
302
 
303
- # -----------------------
304
  # DEMO PAGE
305
- # -----------------------
306
- def show_demo():
307
- # Header row: model selector (left), info button (middle), 2D/3D toggle (right)
308
- h1, h2, h3 = st.columns([3, 1, 1])
309
- with h1:
310
- labels = list(EMBED_MODELS.keys())
311
- name_map = EMBED_MODELS
312
- # get current label from state value
313
- rev = {v: k for k, v in name_map.items()}
314
- current_label = rev.get(st.session_state.model_name, labels[0])
315
- chosen_label = st.selectbox("Embedding model", labels, index=labels.index(current_label))
316
- st.session_state.model_name = name_map[chosen_label]
317
- with h2:
318
- if st.button(" Info"):
319
- set_page("info")
320
- with h3:
321
- st.session_state.proj_mode = st.radio("", ["2D", "3D"], horizontal=True,
322
- index=1 if st.session_state.proj_mode == "3D" else 0)
323
-
324
- # Two-column layout: dataset+editor (left), plot (right)
325
- left, right = st.columns([1, 3], gap="large")
326
-
327
- with left:
328
- ds_keys = list(DATASETS.keys())
329
- st.session_state.dataset_select = st.selectbox(
330
- "Dataset", ds_keys, index=ds_keys.index(st.session_state.dataset_select)
331
- )
332
- default_text = "\n".join(DATASETS[st.session_state.dataset_select])
333
- user_text = st.text_area("Words (one per line)", value=default_text, height=300)
334
- words = [w.strip() for w in user_text.split("\n") if w.strip()]
335
- if len(words) < 3:
336
- st.info("Enter at least three lines to project.")
337
- st.stop()
338
-
339
- # Rotate toggle + speed
340
- r1, r2 = st.columns([1, 1])
341
- with r1:
342
- if st.button("⏯ Rotate"):
343
- st.session_state.rotate_on = not st.session_state.rotate_on
344
- with r2:
345
- speed = st.slider("Speed (deg/frame)", 1, 10, 3, help="Rotation speed in degrees per frame (3 is smooth).")
346
-
347
- with right:
348
- # Embed (cached) and PCA
349
- embs = embed_texts(st.session_state.model_name, tuple(words))
350
- n_comp = 3 if st.session_state.proj_mode == "3D" else 2
351
- coords = PCA(n_components=n_comp).fit_transform(embs)
352
-
353
- # Title centered & blue
354
- title_html = f"<b style='color:#1f77b4; font-size:2.2rem;'>{st.session_state.dataset_select}</b>"
355
-
356
- # Build figure
357
- if st.session_state.proj_mode == "3D":
358
- eye = st.session_state.live_camera.get("eye", {"x": 1.8, "y": 0.0, "z": 1.0})
359
- fig = go.Figure(
360
- data=[go.Scatter3d(
361
- x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
362
- mode="markers+text", text=words, textposition="top center",
363
- marker=dict(size=6),
364
- )],
365
- layout=go.Layout(
366
- title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
367
- font=dict(size=32, color="#1f77b4")),
368
- scene=dict(
369
- camera=dict(eye=eye, projection=dict(type="perspective")),
370
- xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
371
- yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
372
- zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
373
- ),
374
- margin=dict(l=0, r=0, b=0, t=60),
375
- uirevision="keep", # keep interactions across reruns
376
- )
377
- )
378
- else:
379
- fig = go.Figure(
380
- data=[go.Scatter(
381
- x=coords[:, 0], y=coords[:, 1],
382
- mode="markers+text", text=words, textposition="top center",
383
- marker=dict(size=10),
384
- )],
385
- layout=go.Layout(
386
- title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
387
- font=dict(size=32, color="#1f77b4")),
388
- xaxis=dict(title="PC1"),
389
- yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
390
- margin=dict(l=0, r=0, b=0, t=60),
391
- )
392
- )
393
 
394
- # Use a placeholder so we can update the figure in-place during rotation
395
- ph = st.empty()
396
 
397
- # First render
398
- if HAVE_EVENTS and st.session_state.proj_mode == "3D":
399
- # Capture relayout to update live camera if user moved it before hitting rotate
400
- _ = plotly_events(
401
- fig, events=["relayout"], select_event=False, click_event=False, hover_event=False,
402
- override_width="100%", override_height=None, key="plot3d"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  )
404
- # Update stored camera from relayout payload if present
405
- rel_key = "plotly_relayout-data-plot3d"
406
- rel = st.session_state.get(rel_key)
407
- if isinstance(rel, dict):
408
- cam_update = rel.get("scene.camera") or {}
409
- if "eye" in cam_update:
410
- st.session_state.live_camera = {
411
- "eye": cam_update["eye"],
412
- "projection": {"type": "perspective"}
413
- }
414
- else:
415
- # Sometimes relayout sends separate eye components
416
- keys = ["scene.camera.eye.x", "scene.camera.eye.y", "scene.camera.eye.z"]
417
- if all(k in rel for k in keys):
418
- st.session_state.live_camera = {
419
- "eye": {"x": rel[keys[0]], "y": rel[keys[1]], "z": rel[keys[2]]},
420
- "projection": {"type": "perspective"}
421
- }
422
- else:
423
- ph.plotly_chart(fig, use_container_width=True)
424
-
425
- # Continuous rotation loop (only in 3D and when rotate_on is True)
426
- if st.session_state.proj_mode == "3D" and st.session_state.rotate_on:
427
- # Start from the CURRENT live camera (which includes any manual rotation)
428
- eye = st.session_state.live_camera["eye"]
429
- r = max(1e-6, float(np.sqrt(eye["x"]**2 + eye["y"]**2)))
430
- z_eye = float(eye["z"])
431
- # Current angle in radians based on x,y
432
- angle = float(np.arctan2(eye["y"], eye["x"]))
433
-
434
- # Render a batch of frames, then rerun to keep UI responsive
435
- frames = 180 # ~6 seconds at 33ms per frame
436
- step = np.deg2rad(speed)
437
-
438
- for _ in range(frames):
439
- # Update angle and camera eye
440
- angle += step
441
- new_eye = {"x": r * np.cos(angle), "y": r * np.sin(angle), "z": z_eye}
442
- st.session_state.live_camera = {"eye": new_eye, "projection": {"type": "perspective"}}
443
- fig.update_layout(scene_camera=st.session_state.live_camera)
444
-
445
- # Render updated frame
446
- ph.plotly_chart(fig, use_container_width=True)
447
- time.sleep(0.033) # ~30 FPS
448
 
449
- # If user toggled rotation off during this batch, break early
450
- if not st.session_state.rotate_on:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  break
 
 
 
 
 
 
452
 
453
- # If still on after this batch, trigger another rerun to continue rotating
454
- if st.session_state.rotate_on:
455
  st.rerun()
456
 
457
- # -----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  # ROUTER
459
- # -----------------------
460
  if page == "info":
461
- show_info()
462
  else:
463
- show_demo()
 
1
  # app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import time
3
  import random
4
  import numpy as np
 
8
  import torch
9
  from transformers import AutoTokenizer, AutoModel
10
 
 
 
 
 
 
 
 
 
11
  st.set_page_config(page_title="Embedding Demo", layout="wide")
12
 
13
+ # ----------------------------
14
+ # BASE DATASETS (lowercase)
15
+ # ----------------------------
16
  DATASETS = {
17
+ "countries": [
18
+ "germany","france","italy","spain","portugal","poland","netherlands","belgium",
19
+ "austria","switzerland","greece","norway","sweden","finland","denmark","ireland",
20
+ "hungary","czechia","slovakia","slovenia","iceland","estonia","latvia","lithuania","romania"
21
+ ],
22
+ "animals": [
23
+ "cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
24
+ "zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
25
+ "chimpanzee","gorilla","leopard","cheetah","lynx"
26
+ ],
27
+ "furniture": [
28
+ "armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
29
+ "nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
30
+ "ottoman","console table","vanity","buffet","sectional sofa"
31
+ ],
32
+ "actors": [
33
+ "brad pitt","angelina jolie","meryl streep","leonardo dicaprio","tom hanks","scarlett johansson","robert de niro",
34
+ "natalie portman","matt damon","cate blanchett","johnny depp","keanu reeves","hugh jackman","emma stone","ryan gosling",
35
+ "jennifer lawrence","christian bale","charlize theron","will smith","anne hathaway","denzel washington","morgan freeman",
36
+ "julia roberts","george clooney","kate winslet"
37
+ ],
38
+ "rock group": [
39
+ "the beatles","rolling stones","pink floyd","queen","led zeppelin","u2","ac/dc","nirvana","radiohead","metallica",
40
+ "guns n' roses","red hot chili peppers","coldplay","pearl jam","the police","aerosmith","green day","foo fighters",
41
+ "the doors","bon jovi","deep purple","the who","the kinks","fleetwood mac","the beach boys"
42
+ ],
43
+ "sports": [
44
+ "soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
45
+ "boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
46
+ "cricket","table tennis","gymnastics","hockey","climbing"
47
+ ]
48
  }
49
 
50
+ # ----------------------------
51
+ # RANDOM MIXED SETS (once per session)
52
+ # ----------------------------
53
+ def make_random_mixed_sets(base: dict, n_sets: int = 3) -> dict:
54
+ keys = list(base.keys())
 
 
55
  mixed = {}
56
  for _ in range(n_sets):
57
+ sources = random.sample(keys, 3)
58
  items = []
59
  for s in sources:
60
+ take = min(7, len(base[s]))
61
+ items.extend(random.sample(base[s], take))
62
+ mixed_name = "/".join(sources).lower()
63
+ mixed[mixed_name] = items[:21]
 
64
  return mixed
65
 
66
+ if "mixed_added" not in st.session_state:
67
+ DATASETS.update(make_random_mixed_sets(DATASETS, 3))
68
+ st.session_state.mixed_added = True
 
69
 
70
+ # ----------------------------
71
  # MODELS (transformers)
72
+ # ----------------------------
73
  EMBED_MODELS = {
74
+ "all-minilm-l6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
75
  "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
76
  "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
77
  }
 
90
  with torch.no_grad():
91
  inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
92
  outputs = model(**inputs)
 
93
  token_embeddings = outputs.last_hidden_state # (B,T,H)
94
  mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
95
  summed = (token_embeddings * mask).sum(dim=1)
96
  counts = mask.sum(dim=1).clamp(min=1e-9)
97
+ embeddings = summed / counts # mean pooling
98
  return embeddings.cpu().numpy()
99
 
100
+ # ----------------------------
101
+ # STATE: camera + rotation
102
+ # ----------------------------
103
+ if "camera_eye" not in st.session_state:
104
+ st.session_state.camera_eye = {"x": 1.6, "y": 1.6, "z": 1.2}
105
+ if "spinning" not in st.session_state:
106
+ st.session_state.spinning = False
107
+ if "angle_rad" not in st.session_state:
108
+ # derive from eye.x, eye.y
109
+ e = st.session_state.camera_eye
110
+ st.session_state.angle_rad = float(np.arctan2(e["y"], e["x"]))
111
+
112
+ def update_eye_from_angle(angle_rad: float, radius: float, z: float):
113
+ return {"x": radius * np.cos(angle_rad), "y": radius * np.sin(angle_rad), "z": z}
 
 
 
 
 
 
 
 
 
114
 
115
+ # ----------------------------
116
+ # NAVIGATION via st.query_params
117
+ # ----------------------------
118
+ def goto(page: str):
119
+ st.query_params["page"] = page
120
  st.rerun()
121
 
122
+ page = st.query_params.get("page", ["demo"])[0]
123
+
124
+ # ----------------------------
125
  # INFO PAGE
126
+ # ----------------------------
127
+ def info_page():
128
+ st.title("ℹ about this demo")
129
+ st.write("""
130
+ **embeddings** turn words (or longer text) into numerical vectors.
131
+ in this vector space, **semantically related** items end up **near** each other.
132
+
133
+ why this is useful:
134
+ - semantic search and retrieval
135
+ - clustering and topic discovery
136
+ - recommendations and deduplication
137
  - measuring similarity and analogies
138
 
139
+ this demo embeds single words with a selectable model, reduces to 2d/3d with pca,
140
  and shows how related words cluster in the projected space.
141
+ """.strip())
142
+ if st.button("⬅ back to demo"):
143
+ goto("demo")
 
144
 
145
+ # ----------------------------
146
  # DEMO PAGE
147
+ # ----------------------------
148
+ def demo_page():
149
+ # top row: dataset, model + 2d/3d, info button
150
+ c1, c2, c3 = st.columns([2, 2, 1])
151
+ with c1:
152
+ ds_names = list(DATASETS.keys())
153
+ dataset_name = st.selectbox("dataset", ds_names, index=ds_names.index("furniture") if "furniture" in ds_names else 0)
154
+ with c2:
155
+ cc1, cc2 = st.columns([2, 1])
156
+ with cc1:
157
+ model_label = st.selectbox("embedding model", list(EMBED_MODELS.keys()))
158
+ model_name = EMBED_MODELS[model_label]
159
+ with cc2:
160
+ proj_mode = st.radio("projection", ["2d", "3d"], horizontal=True)
161
+ with c3:
162
+ if st.button("ℹ info"):
163
+ goto("info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ words = DATASETS[dataset_name]
166
+ st.text_area("dataset words", "\n".join(words), height=160)
167
 
168
+ # Embed + PCA
169
+ embs = embed_texts(model_name, tuple(words))
170
+ if proj_mode == "2d":
171
+ coords = PCA(n_components=2).fit_transform(embs)
172
+ else:
173
+ coords = PCA(n_components=3).fit_transform(embs)
174
+
175
+ title_html = f"<b style='color:#1f77b4; font-size:2.0rem;'>{dataset_name}</b>"
176
+
177
+ if proj_mode == "3d":
178
+ # compute radius from current eye, keep same z
179
+ eye = st.session_state.camera_eye
180
+ radius = float(np.sqrt(eye["x"]**2 + eye["y"]**2)) or 1.6
181
+ z_eye = float(eye["z"])
182
+
183
+ fig = go.Figure(
184
+ data=[go.Scatter3d(
185
+ x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
186
+ mode="markers+text", text=words, textposition="top center",
187
+ marker=dict(size=6),
188
+ )],
189
+ layout=go.Layout(
190
+ title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
191
+ font=dict(size=30, color="#1f77b4")),
192
+ scene=dict(
193
+ camera=dict(eye=eye, projection=dict(type="perspective")),
194
+ xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
195
+ yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
196
+ zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
197
+ ),
198
+ margin=dict(l=0, r=0, b=0, t=60),
199
+ uirevision="keep",
200
  )
201
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ # Controls above the plot (created before the chart placeholder): Start/Stop rotation
204
+ b1, b2 = st.columns([1, 1])
205
+ with b1:
206
+ start_clicked = st.button("▶ start rotation", disabled=st.session_state.spinning)
207
+ with b2:
208
+ stop_clicked = st.button("⏹ stop rotation", disabled=not st.session_state.spinning)
209
+
210
+ # If start pressed: turn on spinner and initialize angle from current stored eye
211
+ if start_clicked:
212
+ st.session_state.spinning = True
213
+ # start from stored angle (not capturing manual camera — simple approach)
214
+ st.session_state.angle_rad = float(np.arctan2(eye["y"], eye["x"]))
215
+ # fall through to loop below (this turn will render once, then continue)
216
+
217
+ # If stop pressed: turn off spinner (and keep stop disabled after)
218
+ if stop_clicked:
219
+ st.session_state.spinning = False
220
+
221
+ # Live render placeholder
222
+ placeholder = st.empty()
223
+
224
+ # First draw (static) before any loop
225
+ placeholder.plotly_chart(fig, use_container_width=True)
226
+
227
+ # Continuous rotation loop while spinning
228
+ if st.session_state.spinning:
229
+ # one "batch" of frames, then rerun to keep UI responsive
230
+ steps_per_batch = 120
231
+ step = np.deg2rad(3) # 3 degrees per frame ~ smooth
232
+ for _ in range(steps_per_batch):
233
+ if not st.session_state.spinning:
234
  break
235
+ st.session_state.angle_rad += step
236
+ new_eye = update_eye_from_angle(st.session_state.angle_rad, radius, z_eye)
237
+ st.session_state.camera_eye = new_eye
238
+ fig.update_layout(scene_camera=dict(eye=new_eye, projection=dict(type="perspective")))
239
+ placeholder.plotly_chart(fig, use_container_width=True)
240
+ time.sleep(0.033) # ~30 FPS
241
 
242
+ # If still spinning after this batch, rerun to keep going
243
+ if st.session_state.spinning:
244
  st.rerun()
245
 
246
+ else:
247
+ fig = go.Figure(
248
+ data=[go.Scatter(
249
+ x=coords[:, 0], y=coords[:, 1],
250
+ mode="markers+text", text=words, textposition="top center",
251
+ marker=dict(size=9),
252
+ )],
253
+ layout=go.Layout(
254
+ title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
255
+ font=dict(size=30, color="#1f77b4")),
256
+ xaxis=dict(title="PC1"),
257
+ yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
258
+ margin=dict(l=0, r=0, b=0, t=60),
259
+ )
260
+ )
261
+ st.plotly_chart(fig, use_container_width=True)
262
+
263
+ # ----------------------------
264
  # ROUTER
265
+ # ----------------------------
266
  if page == "info":
267
+ info_page()
268
  else:
269
+ demo_page()