chyams Claude Opus 4.6 commited on
Commit
d07fd1f
Β·
1 Parent(s): 4ba6161

Embedding Explorer: annotations fix, MDS normalization, light mode, design polish

Browse files

- Switch from Scatter3d text to scene annotations (fixes Plotly WebGL text sizing bug)
- Normalize MDS output to [-1,1] with fixed axis ranges for consistent label sizes
- Force light mode via head script redirect
- Dark purple buttons (#63348d), light purple block backgrounds (#f3f0f7)
- Remove button-like styling from Gradio labels
- Per-word palette colors with lighter neighbor shades
- Add Lecture 7 design discussion (discussion.md)
- Add Embedding Explorer section to tools/CLAUDE.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +253 -108
app.py CHANGED
@@ -19,7 +19,6 @@ import re
19
 
20
  import numpy as np
21
  import plotly.graph_objects as go
22
- from sklearn.decomposition import PCA
23
  import gradio as gr
24
 
25
  # ── Configuration (all changeable via HF Space env vars) ─────
@@ -49,6 +48,32 @@ DARK = "#1a1a2e"
49
  GRAY = "#888888"
50
  BG = "#fafafa"
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # ── Load GloVe embeddings on startup ─────────────────────────
53
 
54
  import time
@@ -114,32 +139,34 @@ def parse_expression(expr):
114
  return pos, neg, ordered
115
 
116
 
117
- def fit_pca_3d(vectors):
118
- """PCA β†’ 3D. Returns (coords, fitted_pca)."""
 
 
119
  n = len(vectors)
120
- nc = min(3, n, vectors.shape[1])
121
- pca = PCA(n_components=nc)
122
- coords = pca.fit_transform(vectors)
 
 
 
 
123
  if nc < 3:
124
  coords = np.hstack([coords, np.zeros((n, 3 - nc))])
125
- return coords, pca
126
-
127
-
128
- def project_3d(pca, vectors):
129
- """Project new vectors into an existing PCA space."""
130
- coords = pca.transform(vectors)
131
- if coords.shape[1] < 3:
132
- coords = np.hstack([coords, np.zeros((len(vectors), 3 - coords.shape[1]))])
133
  return coords
134
 
135
 
136
  def _axis(title=""):
137
- """Subtle 3D axis with gridlines for depth perception."""
138
  return dict(
139
  showgrid=True,
140
- gridcolor="rgba(99,52,141,0.08)", # very faint purple grid
141
  zeroline=True,
142
- zerolinecolor="rgba(99,52,141,0.25)",
143
  zerolinewidth=2,
144
  showticklabels=False,
145
  title=title,
@@ -148,13 +175,19 @@ def _axis(title=""):
148
  )
149
 
150
 
151
- def layout_3d(height=560):
152
- """Shared Plotly 3D layout."""
 
 
 
 
 
 
153
  return dict(
154
  scene=dict(
155
- xaxis=_axis(),
156
- yaxis=_axis(),
157
- zaxis=_axis(),
158
  bgcolor=BG,
159
  camera=dict(eye=dict(x=1.5, y=1.5, z=1.2)),
160
  aspectmode="cube",
@@ -163,7 +196,7 @@ def layout_3d(height=560):
163
  margin=dict(l=0, r=0, t=10, b=10),
164
  showlegend=True,
165
  legend=dict(
166
- yanchor="top", y=0.99, xanchor="left", x=0.01,
167
  bgcolor="rgba(255,255,255,0.85)",
168
  font=dict(family="Inter, sans-serif", size=12),
169
  ),
@@ -172,15 +205,18 @@ def layout_3d(height=560):
172
  )
173
 
174
 
175
- def add_vectors_from_origin(fig, coords, labels, color=PURPLE, width=3):
 
176
  """Draw vector lines from origin to each point (words ARE vectors)."""
177
  for i, label in enumerate(labels):
 
 
178
  fig.add_trace(go.Scatter3d(
179
  x=[0, coords[i, 0]],
180
  y=[0, coords[i, 1]],
181
  z=[0, coords[i, 2]],
182
  mode="lines",
183
- line=dict(color=color, width=width),
184
  showlegend=False, hoverinfo="none",
185
  ))
186
  # Origin marker
@@ -232,77 +268,141 @@ def explore(words_text, selected):
232
 
233
  valid = valid[:12] # cap to keep plot readable
234
 
235
- # Main word vectors β€” PCA fitted on these only (stable positions)
236
- vecs = np.array([model[w] for w in valid])
237
- main_coords, pca = fit_pca_3d(vecs)
238
-
239
- # Neighbors (projected into the same PCA space)
240
- nbr_data, nbr_coords = [], None
241
  if selected and selected != "(clear)" and selected in VOCAB and selected in valid:
242
  nbrs = model.most_similar(selected, topn=N_NEIGHBORS)
243
  nbr_data = [(w, s) for w, s in nbrs if w not in valid]
244
- if nbr_data:
245
- nv = np.array([model[w] for w, _ in nbr_data])
246
- nbr_coords = project_3d(pca, nv)
247
  else:
248
  selected = None
249
 
 
 
 
 
 
 
 
 
 
 
 
250
  # ── Build figure ──
251
  fig = go.Figure()
 
 
 
 
 
 
 
 
 
252
 
253
- # Vector lines from origin to each word
254
- add_vectors_from_origin(fig, main_coords, valid)
255
-
256
- # Main words (markers + labels at tips of vectors)
257
- fig.add_trace(go.Scatter3d(
258
- x=main_coords[:, 0].tolist(),
259
- y=main_coords[:, 1].tolist(),
260
- z=main_coords[:, 2].tolist(),
261
- mode="markers+text",
262
- text=valid,
263
- textposition="top center",
264
- textfont=dict(size=14, color=DARK),
265
- marker=dict(
266
- size=9, color=PURPLE, opacity=0.9,
267
- line=dict(width=1, color="white"),
268
- ),
269
- name="Words",
270
- hoverinfo="text",
271
- hovertext=valid,
272
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- # Neighbors + connection lines
275
- if nbr_data and nbr_coords is not None:
276
  fig.add_trace(go.Scatter3d(
277
- x=nbr_coords[:, 0].tolist(),
278
- y=nbr_coords[:, 1].tolist(),
279
- z=nbr_coords[:, 2].tolist(),
280
- mode="markers+text",
281
- text=[w for w, _ in nbr_data],
282
- textposition="top center",
283
- textfont=dict(size=11, color=GRAY),
284
  marker=dict(
285
- size=5, color=PURPLE_LIGHT, opacity=0.8,
286
- line=dict(width=1, color=PURPLE),
287
  ),
288
- name=f'Near "{selected}"',
289
  hoverinfo="text",
290
- hovertext=[f"{w} ({s:.3f})" for w, s in nbr_data],
291
  ))
 
 
292
 
293
- # Dotted lines from selected word to its neighbors
294
- si = valid.index(selected)
295
- for i in range(len(nbr_data)):
296
  fig.add_trace(go.Scatter3d(
297
- x=[main_coords[si, 0], nbr_coords[i, 0]],
298
- y=[main_coords[si, 1], nbr_coords[i, 1]],
299
- z=[main_coords[si, 2], nbr_coords[i, 2]],
300
- mode="lines",
301
- line=dict(color=PURPLE_LIGHT, width=2, dash="dot"),
302
- showlegend=False, hoverinfo="none",
 
 
 
 
 
303
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- fig.update_layout(**layout_3d())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  # Status
308
  status = f"**{len(valid)} words** in 3D"
@@ -358,14 +458,14 @@ def arithmetic(expression):
358
  for w in neg:
359
  rv -= model[w]
360
 
361
- # Collect all words for PCA
362
  result_words = [w for w, _ in result_entries]
363
  all_words = operands + result_words
364
  all_vecs = np.array([model[w] for w in all_words])
365
 
366
- # Include result vector in PCA fit for best view
367
  combined = np.vstack([all_vecs, rv.reshape(1, -1)])
368
- coords_all, _ = fit_pca_3d(combined)
369
 
370
  rv_coord = coords_all[-1]
371
  coords = coords_all[:-1]
@@ -409,7 +509,7 @@ def arithmetic(expression):
409
  mode="markers+text",
410
  text=[operands[i] for i in pi],
411
  textposition="top center",
412
- textfont=dict(size=14, color=DARK),
413
  marker=dict(size=10, color=PURPLE, opacity=0.9),
414
  name="Positive (+)",
415
  hoverinfo="text",
@@ -426,7 +526,7 @@ def arithmetic(expression):
426
  mode="markers+text",
427
  text=[operands[i] for i in ni],
428
  textposition="top center",
429
- textfont=dict(size=14, color=DARK),
430
  marker=dict(size=10, color=PINK, opacity=0.9),
431
  name="Negative (βˆ’)",
432
  hoverinfo="text",
@@ -440,7 +540,7 @@ def arithmetic(expression):
440
  mode="markers+text",
441
  text=[f"β‰ˆ {top_result[0]}"],
442
  textposition="top center",
443
- textfont=dict(size=16, color=PURPLE),
444
  marker=dict(
445
  size=14, color=GOLD, opacity=1.0, symbol="diamond",
446
  line=dict(width=2, color=PURPLE),
@@ -460,8 +560,8 @@ def arithmetic(expression):
460
  mode="markers+text",
461
  text=[f"{w} ({s:.2f})" for w, s in other_results],
462
  textposition="top center",
463
- textfont=dict(size=11, color=GRAY),
464
- marker=dict(size=5, color=PURPLE_LIGHT, opacity=0.7),
465
  name="Other matches",
466
  hoverinfo="text",
467
  hovertext=[f"{w} ({s:.3f})" for w, s in other_results],
@@ -514,13 +614,58 @@ def arithmetic(expression):
514
  CSS = """
515
  .gradio-container { max-width: 1200px !important; }
516
  h1 { color: #63348d !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  """
518
 
 
 
519
  with gr.Blocks(
520
  title="Embedding Explorer",
 
521
  theme=gr.themes.Soft(
522
  primary_hue="purple",
523
  font=gr.themes.GoogleFont("Inter"),
 
 
 
 
 
 
 
524
  ),
525
  css=CSS,
526
  ) as demo:
@@ -543,27 +688,27 @@ with gr.Blocks(
543
  "Similar words cluster together. "
544
  "Click a word below the plot to see its nearest neighbors.*"
545
  )
 
546
  with gr.Row():
547
- with gr.Column(scale=1):
548
  exp_in = gr.Textbox(
549
  label="Words (space-separated, min 3)",
550
  placeholder="dog cat fish car truck",
551
  lines=1,
552
  )
 
553
  exp_btn = gr.Button("Show in 3D", variant="primary")
554
- exp_status = gr.Markdown("")
555
- exp_radio = gr.Radio(
556
- label="Click to see nearest neighbors",
557
- choices=[], value=None,
558
- visible=False, interactive=True,
559
- )
560
- gr.Examples(
561
- examples=[[e] for e in EXPLORE_EXAMPLES],
562
- inputs=[exp_in],
563
- label="Try these",
564
- )
565
- with gr.Column(scale=2):
566
- exp_plot = gr.Plot(label="Embedding Space")
567
 
568
  # Events
569
  exp_btn.click(
@@ -588,22 +733,22 @@ with gr.Blocks(
588
  "*Word vectors can do math! "
589
  "The results reveal hidden relationships between words.*"
590
  )
 
591
  with gr.Row():
592
- with gr.Column(scale=1):
593
  math_in = gr.Textbox(
594
  label="Expression",
595
  placeholder="king - man + woman",
596
  lines=1,
597
  )
 
598
  math_btn = gr.Button("Compute", variant="primary")
599
- math_status = gr.Markdown("")
600
- gr.Examples(
601
- examples=[[e] for e in ARITHMETIC_EXAMPLES],
602
- inputs=[math_in],
603
- label="Try these",
604
- )
605
- with gr.Column(scale=2):
606
- math_plot = gr.Plot(label="Vector Arithmetic")
607
 
608
  # Events
609
  math_btn.click(
 
19
 
20
  import numpy as np
21
  import plotly.graph_objects as go
 
22
  import gradio as gr
23
 
24
  # ── Configuration (all changeable via HF Space env vars) ─────
 
48
  GRAY = "#888888"
49
  BG = "#fafafa"
50
 
51
+ # Categorical palette for up to 12 words (design-system first, then complements)
52
+ PALETTE = [
53
+ "#63348d", # purple (design primary)
54
+ "#2e86c1", # blue
55
+ "#c0392b", # red
56
+ "#f0c040", # gold (design)
57
+ "#27ae60", # green
58
+ "#de95a0", # pink (design)
59
+ "#e67e22", # orange
60
+ "#1abc9c", # teal
61
+ "#8e44ad", # violet
62
+ "#d4ac0d", # dark gold
63
+ "#2980b9", # steel blue
64
+ "#c0392b", # crimson
65
+ ]
66
+
67
+
68
+ def lighten(hex_color, amount=0.3):
69
+ """Lighten a hex color by blending toward white. amount=0.3 β†’ 30% lighter."""
70
+ h = hex_color.lstrip("#")
71
+ r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
72
+ r = int(r + (255 - r) * amount)
73
+ g = int(g + (255 - g) * amount)
74
+ b = int(b + (255 - b) * amount)
75
+ return f"#{r:02x}{g:02x}{b:02x}"
76
+
77
  # ── Load GloVe embeddings on startup ─────────────────────────
78
 
79
  import time
 
139
  return pos, neg, ordered
140
 
141
 
142
+ def reduce_3d(vectors):
143
+ """MDS (cosine distance) β†’ 3D. Normalizes to [-1, 1] for consistent label sizing."""
144
+ from sklearn.manifold import MDS
145
+ from sklearn.metrics.pairwise import cosine_distances
146
  n = len(vectors)
147
+ if n < 2:
148
+ return np.zeros((n, 3))
149
+ dist = cosine_distances(vectors)
150
+ nc = min(3, n)
151
+ mds = MDS(n_components=nc, dissimilarity="precomputed",
152
+ random_state=42, normalized_stress="auto", max_iter=300)
153
+ coords = mds.fit_transform(dist)
154
  if nc < 3:
155
  coords = np.hstack([coords, np.zeros((n, 3 - nc))])
156
+ # Normalize to [-1, 1] so axis ranges and label sizes are consistent
157
+ max_abs = np.abs(coords).max()
158
+ if max_abs > 1e-8:
159
+ coords = coords / max_abs
 
 
 
 
160
  return coords
161
 
162
 
163
  def _axis(title=""):
164
+ """3D axis with visible gridlines for depth perception."""
165
  return dict(
166
  showgrid=True,
167
+ gridcolor="rgba(99,52,141,0.18)",
168
  zeroline=True,
169
+ zerolinecolor="rgba(99,52,141,0.40)",
170
  zerolinewidth=2,
171
  showticklabels=False,
172
  title=title,
 
175
  )
176
 
177
 
178
+ def layout_3d(height=700):
179
+ """Shared Plotly 3D layout. Fixed [-1.3, 1.3] range; preserves camera between updates."""
180
+ ax_x, ax_y, ax_z = _axis(), _axis(), _axis()
181
+ # Fixed range since reduce_3d normalizes to [-1, 1]; padding for labels
182
+ fixed = [-1.3, 1.3]
183
+ ax_x["range"] = fixed
184
+ ax_y["range"] = fixed
185
+ ax_z["range"] = fixed
186
  return dict(
187
  scene=dict(
188
+ xaxis=ax_x,
189
+ yaxis=ax_y,
190
+ zaxis=ax_z,
191
  bgcolor=BG,
192
  camera=dict(eye=dict(x=1.5, y=1.5, z=1.2)),
193
  aspectmode="cube",
 
196
  margin=dict(l=0, r=0, t=10, b=10),
197
  showlegend=True,
198
  legend=dict(
199
+ yanchor="top", y=0.99, xanchor="right", x=0.99,
200
  bgcolor="rgba(255,255,255,0.85)",
201
  font=dict(family="Inter, sans-serif", size=12),
202
  ),
 
205
  )
206
 
207
 
208
+ def add_vectors_from_origin(fig, coords, labels, colors=None, widths=None,
209
+ default_color=PURPLE, default_width=3):
210
  """Draw vector lines from origin to each point (words ARE vectors)."""
211
  for i, label in enumerate(labels):
212
+ c = colors[i] if colors else default_color
213
+ w = widths[i] if widths else default_width
214
  fig.add_trace(go.Scatter3d(
215
  x=[0, coords[i, 0]],
216
  y=[0, coords[i, 1]],
217
  z=[0, coords[i, 2]],
218
  mode="lines",
219
+ line=dict(color=c, width=w),
220
  showlegend=False, hoverinfo="none",
221
  ))
222
  # Origin marker
 
268
 
269
  valid = valid[:12] # cap to keep plot readable
270
 
271
+ # Gather neighbors first so we can MDS everything together
272
+ nbr_data = []
 
 
 
 
273
  if selected and selected != "(clear)" and selected in VOCAB and selected in valid:
274
  nbrs = model.most_similar(selected, topn=N_NEIGHBORS)
275
  nbr_data = [(w, s) for w, s in nbrs if w not in valid]
 
 
 
276
  else:
277
  selected = None
278
 
279
+ # Build combined word list and run MDS once
280
+ all_words = valid + [w for w, _ in nbr_data]
281
+ all_vecs = np.array([model[w] for w in all_words])
282
+ all_coords = reduce_3d(all_vecs)
283
+
284
+ main_coords = all_coords[:len(valid)]
285
+ nbr_coords = all_coords[len(valid):] if nbr_data else None
286
+
287
+ # ── Assign a distinct color to each word from the palette ──
288
+ word_colors = {w: PALETTE[i % len(PALETTE)] for i, w in enumerate(valid)}
289
+
290
  # ── Build figure ──
291
  fig = go.Figure()
292
+ # Collect 3D annotations (HTML overlays β€” immune to WebGL text sizing bug)
293
+ annotations = []
294
+
295
+ def add_label(x, y, z, text, size=16, color=DARK, opacity=1.0):
296
+ annotations.append(dict(
297
+ x=x, y=y, z=z, text=text, showarrow=False,
298
+ font=dict(size=size, color=color, family="Inter, sans-serif"),
299
+ opacity=opacity, yshift=18,
300
+ ))
301
 
302
+ if selected:
303
+ si = valid.index(selected)
304
+ sel_color = word_colors[selected]
305
+ nbr_color = lighten(sel_color, 0.3)
306
+
307
+ # Vector lines β€” selected word's line is bold, others are faded
308
+ vec_colors = [sel_color if w == selected else lighten(word_colors[w], 0.5)
309
+ for w in valid]
310
+ vec_widths = [4 if w == selected else 1.5 for w in valid]
311
+ add_vectors_from_origin(fig, main_coords, valid,
312
+ colors=vec_colors, widths=vec_widths)
313
+
314
+ # Dimmed words (not selected) β€” markers only
315
+ dim_idx = [i for i in range(len(valid)) if i != si]
316
+ if dim_idx:
317
+ fig.add_trace(go.Scatter3d(
318
+ x=[main_coords[i, 0] for i in dim_idx],
319
+ y=[main_coords[i, 1] for i in dim_idx],
320
+ z=[main_coords[i, 2] for i in dim_idx],
321
+ mode="markers",
322
+ marker=dict(
323
+ size=8,
324
+ color=[lighten(word_colors[valid[i]], 0.4) for i in dim_idx],
325
+ opacity=0.5,
326
+ line=dict(width=1, color="white"),
327
+ ),
328
+ name="Words",
329
+ hoverinfo="text",
330
+ hovertext=[valid[i] for i in dim_idx],
331
+ ))
332
+ for i in dim_idx:
333
+ add_label(main_coords[i, 0], main_coords[i, 1], main_coords[i, 2],
334
+ valid[i], size=15,
335
+ color=lighten(word_colors[valid[i]], 0.4), opacity=0.7)
336
 
337
+ # Selected word β€” bright, full color, markers only
 
338
  fig.add_trace(go.Scatter3d(
339
+ x=[main_coords[si, 0]],
340
+ y=[main_coords[si, 1]],
341
+ z=[main_coords[si, 2]],
342
+ mode="markers",
 
 
 
343
  marker=dict(
344
+ size=14, color=sel_color, opacity=1.0,
345
+ line=dict(width=2, color="white"),
346
  ),
347
+ name=f"Selected: {selected}",
348
  hoverinfo="text",
349
+ hovertext=[f"{selected} (selected)"],
350
  ))
351
+ add_label(main_coords[si, 0], main_coords[si, 1], main_coords[si, 2],
352
+ f"<b>{selected}</b>", size=18, color=DARK)
353
 
354
+ # Neighbors β€” lighter shade of the selected word's color
355
+ if nbr_data and nbr_coords is not None:
 
356
  fig.add_trace(go.Scatter3d(
357
+ x=nbr_coords[:, 0].tolist(),
358
+ y=nbr_coords[:, 1].tolist(),
359
+ z=nbr_coords[:, 2].tolist(),
360
+ mode="markers",
361
+ marker=dict(
362
+ size=9, color=nbr_color, opacity=0.9,
363
+ line=dict(width=1, color=sel_color),
364
+ ),
365
+ name=f'Near "{selected}"',
366
+ hoverinfo="text",
367
+ hovertext=[f"{w} ({s:.3f})" for w, s in nbr_data],
368
  ))
369
+ for i, (w, s) in enumerate(nbr_data):
370
+ add_label(nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2],
371
+ w, size=16, color=DARK)
372
+
373
+ # Dotted lines from selected word to its neighbors
374
+ for i in range(len(nbr_data)):
375
+ fig.add_trace(go.Scatter3d(
376
+ x=[main_coords[si, 0], nbr_coords[i, 0]],
377
+ y=[main_coords[si, 1], nbr_coords[i, 1]],
378
+ z=[main_coords[si, 2], nbr_coords[i, 2]],
379
+ mode="lines",
380
+ line=dict(color=nbr_color, width=2, dash="dot"),
381
+ showlegend=False, hoverinfo="none",
382
+ ))
383
+ else:
384
+ # No selection β€” each word gets its own color, markers only
385
+ colors = [word_colors[w] for w in valid]
386
+ add_vectors_from_origin(fig, main_coords, valid, colors=colors)
387
 
388
+ fig.add_trace(go.Scatter3d(
389
+ x=main_coords[:, 0].tolist(),
390
+ y=main_coords[:, 1].tolist(),
391
+ z=main_coords[:, 2].tolist(),
392
+ mode="markers",
393
+ marker=dict(
394
+ size=10, color=colors, opacity=0.9,
395
+ line=dict(width=1, color="white"),
396
+ ),
397
+ name="Words",
398
+ hoverinfo="text",
399
+ hovertext=valid,
400
+ ))
401
+ for i, w in enumerate(valid):
402
+ add_label(main_coords[i, 0], main_coords[i, 1], main_coords[i, 2],
403
+ w, size=16, color=DARK)
404
+
405
+ fig.update_layout(**layout_3d(), scene_annotations=annotations)
406
 
407
  # Status
408
  status = f"**{len(valid)} words** in 3D"
 
458
  for w in neg:
459
  rv -= model[w]
460
 
461
+ # Collect all words for MDS
462
  result_words = [w for w, _ in result_entries]
463
  all_words = operands + result_words
464
  all_vecs = np.array([model[w] for w in all_words])
465
 
466
+ # Include result vector in MDS fit for best view
467
  combined = np.vstack([all_vecs, rv.reshape(1, -1)])
468
+ coords_all = reduce_3d(combined)
469
 
470
  rv_coord = coords_all[-1]
471
  coords = coords_all[:-1]
 
509
  mode="markers+text",
510
  text=[operands[i] for i in pi],
511
  textposition="top center",
512
+ textfont=dict(size=18, color=DARK),
513
  marker=dict(size=10, color=PURPLE, opacity=0.9),
514
  name="Positive (+)",
515
  hoverinfo="text",
 
526
  mode="markers+text",
527
  text=[operands[i] for i in ni],
528
  textposition="top center",
529
+ textfont=dict(size=18, color=DARK),
530
  marker=dict(size=10, color=PINK, opacity=0.9),
531
  name="Negative (βˆ’)",
532
  hoverinfo="text",
 
540
  mode="markers+text",
541
  text=[f"β‰ˆ {top_result[0]}"],
542
  textposition="top center",
543
+ textfont=dict(size=20, color=PURPLE),
544
  marker=dict(
545
  size=14, color=GOLD, opacity=1.0, symbol="diamond",
546
  line=dict(width=2, color=PURPLE),
 
560
  mode="markers+text",
561
  text=[f"{w} ({s:.2f})" for w, s in other_results],
562
  textposition="top center",
563
+ textfont=dict(size=14, color=GRAY),
564
+ marker=dict(size=6, color=PURPLE_LIGHT, opacity=0.7),
565
  name="Other matches",
566
  hoverinfo="text",
567
  hovertext=[f"{w} ({s:.3f})" for w, s in other_results],
 
614
  CSS = """
615
  .gradio-container { max-width: 1200px !important; }
616
  h1 { color: #63348d !important; }
617
+
618
+
619
+ /* Remove button-like styling from labels */
620
+ .gradio-container label span {
621
+ border: none !important;
622
+ background: transparent !important;
623
+ box-shadow: none !important;
624
+ padding-left: 0 !important;
625
+ font-weight: 600 !important;
626
+ color: #63348d !important;
627
+ }
628
+
629
+ /* Block backgrounds β€” light purple instead of near-white */
630
+ .gradio-container .block {
631
+ background: #f3f0f7 !important;
632
+ border: 1px solid #ded9f4 !important;
633
+ }
634
+
635
+ /* Input fields should be white for contrast against purple blocks */
636
+ .gradio-container textarea,
637
+ .gradio-container input[type="text"] {
638
+ background: #ffffff !important;
639
+ }
640
+
641
+ /* Tab styling */
642
+ .gradio-container button.tab-nav-button.selected {
643
+ color: #63348d !important;
644
+ border-bottom-color: #63348d !important;
645
+ }
646
+
647
+ /* Radio items β€” less boxy */
648
+ .gradio-container .wrap .radio-group label {
649
+ border: none !important;
650
+ background: transparent !important;
651
+ }
652
  """
653
 
654
+ FORCE_LIGHT = '<script>if(!location.search.includes("__theme=light")){const u=new URL(location);u.searchParams.set("__theme","light");location.replace(u)}</script>'
655
+
656
  with gr.Blocks(
657
  title="Embedding Explorer",
658
+ head=FORCE_LIGHT,
659
  theme=gr.themes.Soft(
660
  primary_hue="purple",
661
  font=gr.themes.GoogleFont("Inter"),
662
+ ).set(
663
+ button_primary_background_fill="#63348d",
664
+ button_primary_background_fill_hover="#4a2769",
665
+ button_primary_text_color="#ffffff",
666
+ block_background_fill="#f3f0f7",
667
+ block_border_color="#ded9f4",
668
+ body_background_fill="#ffffff",
669
  ),
670
  css=CSS,
671
  ) as demo:
 
688
  "Similar words cluster together. "
689
  "Click a word below the plot to see its nearest neighbors.*"
690
  )
691
+ exp_plot = gr.Plot(label="Embedding Space")
692
  with gr.Row():
693
+ with gr.Column(scale=2):
694
  exp_in = gr.Textbox(
695
  label="Words (space-separated, min 3)",
696
  placeholder="dog cat fish car truck",
697
  lines=1,
698
  )
699
+ with gr.Column(scale=1):
700
  exp_btn = gr.Button("Show in 3D", variant="primary")
701
+ exp_status = gr.Markdown("")
702
+ exp_radio = gr.Radio(
703
+ label="Click to see nearest neighbors",
704
+ choices=[], value=None,
705
+ visible=False, interactive=True,
706
+ )
707
+ gr.Examples(
708
+ examples=[[e] for e in EXPLORE_EXAMPLES],
709
+ inputs=[exp_in],
710
+ label="Try these",
711
+ )
 
 
712
 
713
  # Events
714
  exp_btn.click(
 
733
  "*Word vectors can do math! "
734
  "The results reveal hidden relationships between words.*"
735
  )
736
+ math_plot = gr.Plot(label="Vector Arithmetic")
737
  with gr.Row():
738
+ with gr.Column(scale=2):
739
  math_in = gr.Textbox(
740
  label="Expression",
741
  placeholder="king - man + woman",
742
  lines=1,
743
  )
744
+ with gr.Column(scale=1):
745
  math_btn = gr.Button("Compute", variant="primary")
746
+ math_status = gr.Markdown("")
747
+ gr.Examples(
748
+ examples=[[e] for e in ARITHMETIC_EXAMPLES],
749
+ inputs=[math_in],
750
+ label="Try these",
751
+ )
 
 
752
 
753
  # Events
754
  math_btn.click(