Spaces:
Sleeping
Sleeping
Embedding Explorer: annotations fix, MDS normalization, light mode, design polish
Browse files- Switch from Scatter3d text to scene annotations (fixes Plotly WebGL text sizing bug)
- Normalize MDS output to [-1,1] with fixed axis ranges for consistent label sizes
- Force light mode via head script redirect
- Dark purple buttons (#63348d), light purple block backgrounds (#f3f0f7)
- Remove button-like styling from Gradio labels
- Per-word palette colors with lighter neighbor shades
- Add Lecture 7 design discussion (discussion.md)
- Add Embedding Explorer section to tools/CLAUDE.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -19,7 +19,6 @@ import re
|
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import plotly.graph_objects as go
|
| 22 |
-
from sklearn.decomposition import PCA
|
| 23 |
import gradio as gr
|
| 24 |
|
| 25 |
# ββ Configuration (all changeable via HF Space env vars) βββββ
|
|
@@ -49,6 +48,32 @@ DARK = "#1a1a2e"
|
|
| 49 |
GRAY = "#888888"
|
| 50 |
BG = "#fafafa"
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# ββ Load GloVe embeddings on startup βββββββββββββββββββββββββ
|
| 53 |
|
| 54 |
import time
|
|
@@ -114,32 +139,34 @@ def parse_expression(expr):
|
|
| 114 |
return pos, neg, ordered
|
| 115 |
|
| 116 |
|
| 117 |
-
def
|
| 118 |
-
"""
|
|
|
|
|
|
|
| 119 |
n = len(vectors)
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
if nc < 3:
|
| 124 |
coords = np.hstack([coords, np.zeros((n, 3 - nc))])
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
"""Project new vectors into an existing PCA space."""
|
| 130 |
-
coords = pca.transform(vectors)
|
| 131 |
-
if coords.shape[1] < 3:
|
| 132 |
-
coords = np.hstack([coords, np.zeros((len(vectors), 3 - coords.shape[1]))])
|
| 133 |
return coords
|
| 134 |
|
| 135 |
|
| 136 |
def _axis(title=""):
|
| 137 |
-
"""
|
| 138 |
return dict(
|
| 139 |
showgrid=True,
|
| 140 |
-
gridcolor="rgba(99,52,141,0.
|
| 141 |
zeroline=True,
|
| 142 |
-
zerolinecolor="rgba(99,52,141,0.
|
| 143 |
zerolinewidth=2,
|
| 144 |
showticklabels=False,
|
| 145 |
title=title,
|
|
@@ -148,13 +175,19 @@ def _axis(title=""):
|
|
| 148 |
)
|
| 149 |
|
| 150 |
|
| 151 |
-
def layout_3d(height=
|
| 152 |
-
"""Shared Plotly 3D layout."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
return dict(
|
| 154 |
scene=dict(
|
| 155 |
-
xaxis=
|
| 156 |
-
yaxis=
|
| 157 |
-
zaxis=
|
| 158 |
bgcolor=BG,
|
| 159 |
camera=dict(eye=dict(x=1.5, y=1.5, z=1.2)),
|
| 160 |
aspectmode="cube",
|
|
@@ -163,7 +196,7 @@ def layout_3d(height=560):
|
|
| 163 |
margin=dict(l=0, r=0, t=10, b=10),
|
| 164 |
showlegend=True,
|
| 165 |
legend=dict(
|
| 166 |
-
yanchor="top", y=0.99, xanchor="
|
| 167 |
bgcolor="rgba(255,255,255,0.85)",
|
| 168 |
font=dict(family="Inter, sans-serif", size=12),
|
| 169 |
),
|
|
@@ -172,15 +205,18 @@ def layout_3d(height=560):
|
|
| 172 |
)
|
| 173 |
|
| 174 |
|
| 175 |
-
def add_vectors_from_origin(fig, coords, labels,
|
|
|
|
| 176 |
"""Draw vector lines from origin to each point (words ARE vectors)."""
|
| 177 |
for i, label in enumerate(labels):
|
|
|
|
|
|
|
| 178 |
fig.add_trace(go.Scatter3d(
|
| 179 |
x=[0, coords[i, 0]],
|
| 180 |
y=[0, coords[i, 1]],
|
| 181 |
z=[0, coords[i, 2]],
|
| 182 |
mode="lines",
|
| 183 |
-
line=dict(color=
|
| 184 |
showlegend=False, hoverinfo="none",
|
| 185 |
))
|
| 186 |
# Origin marker
|
|
@@ -232,77 +268,141 @@ def explore(words_text, selected):
|
|
| 232 |
|
| 233 |
valid = valid[:12] # cap to keep plot readable
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
|
| 237 |
-
main_coords, pca = fit_pca_3d(vecs)
|
| 238 |
-
|
| 239 |
-
# Neighbors (projected into the same PCA space)
|
| 240 |
-
nbr_data, nbr_coords = [], None
|
| 241 |
if selected and selected != "(clear)" and selected in VOCAB and selected in valid:
|
| 242 |
nbrs = model.most_similar(selected, topn=N_NEIGHBORS)
|
| 243 |
nbr_data = [(w, s) for w, s in nbrs if w not in valid]
|
| 244 |
-
if nbr_data:
|
| 245 |
-
nv = np.array([model[w] for w, _ in nbr_data])
|
| 246 |
-
nbr_coords = project_3d(pca, nv)
|
| 247 |
else:
|
| 248 |
selected = None
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
# ββ Build figure ββ
|
| 251 |
fig = go.Figure()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
-
|
| 275 |
-
if nbr_data and nbr_coords is not None:
|
| 276 |
fig.add_trace(go.Scatter3d(
|
| 277 |
-
x=
|
| 278 |
-
y=
|
| 279 |
-
z=
|
| 280 |
-
mode="markers
|
| 281 |
-
text=[w for w, _ in nbr_data],
|
| 282 |
-
textposition="top center",
|
| 283 |
-
textfont=dict(size=11, color=GRAY),
|
| 284 |
marker=dict(
|
| 285 |
-
size=
|
| 286 |
-
line=dict(width=
|
| 287 |
),
|
| 288 |
-
name=f
|
| 289 |
hoverinfo="text",
|
| 290 |
-
hovertext=[f"{
|
| 291 |
))
|
|
|
|
|
|
|
| 292 |
|
| 293 |
-
#
|
| 294 |
-
|
| 295 |
-
for i in range(len(nbr_data)):
|
| 296 |
fig.add_trace(go.Scatter3d(
|
| 297 |
-
x=
|
| 298 |
-
y=
|
| 299 |
-
z=
|
| 300 |
-
mode="
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
# Status
|
| 308 |
status = f"**{len(valid)} words** in 3D"
|
|
@@ -358,14 +458,14 @@ def arithmetic(expression):
|
|
| 358 |
for w in neg:
|
| 359 |
rv -= model[w]
|
| 360 |
|
| 361 |
-
# Collect all words for
|
| 362 |
result_words = [w for w, _ in result_entries]
|
| 363 |
all_words = operands + result_words
|
| 364 |
all_vecs = np.array([model[w] for w in all_words])
|
| 365 |
|
| 366 |
-
# Include result vector in
|
| 367 |
combined = np.vstack([all_vecs, rv.reshape(1, -1)])
|
| 368 |
-
coords_all
|
| 369 |
|
| 370 |
rv_coord = coords_all[-1]
|
| 371 |
coords = coords_all[:-1]
|
|
@@ -409,7 +509,7 @@ def arithmetic(expression):
|
|
| 409 |
mode="markers+text",
|
| 410 |
text=[operands[i] for i in pi],
|
| 411 |
textposition="top center",
|
| 412 |
-
textfont=dict(size=
|
| 413 |
marker=dict(size=10, color=PURPLE, opacity=0.9),
|
| 414 |
name="Positive (+)",
|
| 415 |
hoverinfo="text",
|
|
@@ -426,7 +526,7 @@ def arithmetic(expression):
|
|
| 426 |
mode="markers+text",
|
| 427 |
text=[operands[i] for i in ni],
|
| 428 |
textposition="top center",
|
| 429 |
-
textfont=dict(size=
|
| 430 |
marker=dict(size=10, color=PINK, opacity=0.9),
|
| 431 |
name="Negative (β)",
|
| 432 |
hoverinfo="text",
|
|
@@ -440,7 +540,7 @@ def arithmetic(expression):
|
|
| 440 |
mode="markers+text",
|
| 441 |
text=[f"β {top_result[0]}"],
|
| 442 |
textposition="top center",
|
| 443 |
-
textfont=dict(size=
|
| 444 |
marker=dict(
|
| 445 |
size=14, color=GOLD, opacity=1.0, symbol="diamond",
|
| 446 |
line=dict(width=2, color=PURPLE),
|
|
@@ -460,8 +560,8 @@ def arithmetic(expression):
|
|
| 460 |
mode="markers+text",
|
| 461 |
text=[f"{w} ({s:.2f})" for w, s in other_results],
|
| 462 |
textposition="top center",
|
| 463 |
-
textfont=dict(size=
|
| 464 |
-
marker=dict(size=
|
| 465 |
name="Other matches",
|
| 466 |
hoverinfo="text",
|
| 467 |
hovertext=[f"{w} ({s:.3f})" for w, s in other_results],
|
|
@@ -514,13 +614,58 @@ def arithmetic(expression):
|
|
| 514 |
CSS = """
|
| 515 |
.gradio-container { max-width: 1200px !important; }
|
| 516 |
h1 { color: #63348d !important; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
"""
|
| 518 |
|
|
|
|
|
|
|
| 519 |
with gr.Blocks(
|
| 520 |
title="Embedding Explorer",
|
|
|
|
| 521 |
theme=gr.themes.Soft(
|
| 522 |
primary_hue="purple",
|
| 523 |
font=gr.themes.GoogleFont("Inter"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
),
|
| 525 |
css=CSS,
|
| 526 |
) as demo:
|
|
@@ -543,27 +688,27 @@ with gr.Blocks(
|
|
| 543 |
"Similar words cluster together. "
|
| 544 |
"Click a word below the plot to see its nearest neighbors.*"
|
| 545 |
)
|
|
|
|
| 546 |
with gr.Row():
|
| 547 |
-
with gr.Column(scale=
|
| 548 |
exp_in = gr.Textbox(
|
| 549 |
label="Words (space-separated, min 3)",
|
| 550 |
placeholder="dog cat fish car truck",
|
| 551 |
lines=1,
|
| 552 |
)
|
|
|
|
| 553 |
exp_btn = gr.Button("Show in 3D", variant="primary")
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
with gr.Column(scale=2):
|
| 566 |
-
exp_plot = gr.Plot(label="Embedding Space")
|
| 567 |
|
| 568 |
# Events
|
| 569 |
exp_btn.click(
|
|
@@ -588,22 +733,22 @@ with gr.Blocks(
|
|
| 588 |
"*Word vectors can do math! "
|
| 589 |
"The results reveal hidden relationships between words.*"
|
| 590 |
)
|
|
|
|
| 591 |
with gr.Row():
|
| 592 |
-
with gr.Column(scale=
|
| 593 |
math_in = gr.Textbox(
|
| 594 |
label="Expression",
|
| 595 |
placeholder="king - man + woman",
|
| 596 |
lines=1,
|
| 597 |
)
|
|
|
|
| 598 |
math_btn = gr.Button("Compute", variant="primary")
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
with gr.Column(scale=2):
|
| 606 |
-
math_plot = gr.Plot(label="Vector Arithmetic")
|
| 607 |
|
| 608 |
# Events
|
| 609 |
math_btn.click(
|
|
|
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import plotly.graph_objects as go
|
|
|
|
| 22 |
import gradio as gr
|
| 23 |
|
| 24 |
# ββ Configuration (all changeable via HF Space env vars) βββββ
|
|
|
|
| 48 |
GRAY = "#888888"
|
| 49 |
BG = "#fafafa"
|
| 50 |
|
| 51 |
+
# Categorical palette for up to 12 words (design-system first, then complements)
|
| 52 |
+
PALETTE = [
|
| 53 |
+
"#63348d", # purple (design primary)
|
| 54 |
+
"#2e86c1", # blue
|
| 55 |
+
"#c0392b", # red
|
| 56 |
+
"#f0c040", # gold (design)
|
| 57 |
+
"#27ae60", # green
|
| 58 |
+
"#de95a0", # pink (design)
|
| 59 |
+
"#e67e22", # orange
|
| 60 |
+
"#1abc9c", # teal
|
| 61 |
+
"#8e44ad", # violet
|
| 62 |
+
"#d4ac0d", # dark gold
|
| 63 |
+
"#2980b9", # steel blue
|
| 64 |
+
"#c0392b", # crimson
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def lighten(hex_color, amount=0.3):
|
| 69 |
+
"""Lighten a hex color by blending toward white. amount=0.3 β 30% lighter."""
|
| 70 |
+
h = hex_color.lstrip("#")
|
| 71 |
+
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
| 72 |
+
r = int(r + (255 - r) * amount)
|
| 73 |
+
g = int(g + (255 - g) * amount)
|
| 74 |
+
b = int(b + (255 - b) * amount)
|
| 75 |
+
return f"#{r:02x}{g:02x}{b:02x}"
|
| 76 |
+
|
| 77 |
# ββ Load GloVe embeddings on startup βββββββββββββββββββββββββ
|
| 78 |
|
| 79 |
import time
|
|
|
|
| 139 |
return pos, neg, ordered
|
| 140 |
|
| 141 |
|
| 142 |
+
def reduce_3d(vectors):
|
| 143 |
+
"""MDS (cosine distance) β 3D. Normalizes to [-1, 1] for consistent label sizing."""
|
| 144 |
+
from sklearn.manifold import MDS
|
| 145 |
+
from sklearn.metrics.pairwise import cosine_distances
|
| 146 |
n = len(vectors)
|
| 147 |
+
if n < 2:
|
| 148 |
+
return np.zeros((n, 3))
|
| 149 |
+
dist = cosine_distances(vectors)
|
| 150 |
+
nc = min(3, n)
|
| 151 |
+
mds = MDS(n_components=nc, dissimilarity="precomputed",
|
| 152 |
+
random_state=42, normalized_stress="auto", max_iter=300)
|
| 153 |
+
coords = mds.fit_transform(dist)
|
| 154 |
if nc < 3:
|
| 155 |
coords = np.hstack([coords, np.zeros((n, 3 - nc))])
|
| 156 |
+
# Normalize to [-1, 1] so axis ranges and label sizes are consistent
|
| 157 |
+
max_abs = np.abs(coords).max()
|
| 158 |
+
if max_abs > 1e-8:
|
| 159 |
+
coords = coords / max_abs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
return coords
|
| 161 |
|
| 162 |
|
| 163 |
def _axis(title=""):
|
| 164 |
+
"""3D axis with visible gridlines for depth perception."""
|
| 165 |
return dict(
|
| 166 |
showgrid=True,
|
| 167 |
+
gridcolor="rgba(99,52,141,0.18)",
|
| 168 |
zeroline=True,
|
| 169 |
+
zerolinecolor="rgba(99,52,141,0.40)",
|
| 170 |
zerolinewidth=2,
|
| 171 |
showticklabels=False,
|
| 172 |
title=title,
|
|
|
|
| 175 |
)
|
| 176 |
|
| 177 |
|
| 178 |
+
def layout_3d(height=700):
|
| 179 |
+
"""Shared Plotly 3D layout. Fixed [-1.3, 1.3] range; preserves camera between updates."""
|
| 180 |
+
ax_x, ax_y, ax_z = _axis(), _axis(), _axis()
|
| 181 |
+
# Fixed range since reduce_3d normalizes to [-1, 1]; padding for labels
|
| 182 |
+
fixed = [-1.3, 1.3]
|
| 183 |
+
ax_x["range"] = fixed
|
| 184 |
+
ax_y["range"] = fixed
|
| 185 |
+
ax_z["range"] = fixed
|
| 186 |
return dict(
|
| 187 |
scene=dict(
|
| 188 |
+
xaxis=ax_x,
|
| 189 |
+
yaxis=ax_y,
|
| 190 |
+
zaxis=ax_z,
|
| 191 |
bgcolor=BG,
|
| 192 |
camera=dict(eye=dict(x=1.5, y=1.5, z=1.2)),
|
| 193 |
aspectmode="cube",
|
|
|
|
| 196 |
margin=dict(l=0, r=0, t=10, b=10),
|
| 197 |
showlegend=True,
|
| 198 |
legend=dict(
|
| 199 |
+
yanchor="top", y=0.99, xanchor="right", x=0.99,
|
| 200 |
bgcolor="rgba(255,255,255,0.85)",
|
| 201 |
font=dict(family="Inter, sans-serif", size=12),
|
| 202 |
),
|
|
|
|
| 205 |
)
|
| 206 |
|
| 207 |
|
| 208 |
+
def add_vectors_from_origin(fig, coords, labels, colors=None, widths=None,
|
| 209 |
+
default_color=PURPLE, default_width=3):
|
| 210 |
"""Draw vector lines from origin to each point (words ARE vectors)."""
|
| 211 |
for i, label in enumerate(labels):
|
| 212 |
+
c = colors[i] if colors else default_color
|
| 213 |
+
w = widths[i] if widths else default_width
|
| 214 |
fig.add_trace(go.Scatter3d(
|
| 215 |
x=[0, coords[i, 0]],
|
| 216 |
y=[0, coords[i, 1]],
|
| 217 |
z=[0, coords[i, 2]],
|
| 218 |
mode="lines",
|
| 219 |
+
line=dict(color=c, width=w),
|
| 220 |
showlegend=False, hoverinfo="none",
|
| 221 |
))
|
| 222 |
# Origin marker
|
|
|
|
| 268 |
|
| 269 |
valid = valid[:12] # cap to keep plot readable
|
| 270 |
|
| 271 |
+
# Gather neighbors first so we can MDS everything together
|
| 272 |
+
nbr_data = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
if selected and selected != "(clear)" and selected in VOCAB and selected in valid:
|
| 274 |
nbrs = model.most_similar(selected, topn=N_NEIGHBORS)
|
| 275 |
nbr_data = [(w, s) for w, s in nbrs if w not in valid]
|
|
|
|
|
|
|
|
|
|
| 276 |
else:
|
| 277 |
selected = None
|
| 278 |
|
| 279 |
+
# Build combined word list and run MDS once
|
| 280 |
+
all_words = valid + [w for w, _ in nbr_data]
|
| 281 |
+
all_vecs = np.array([model[w] for w in all_words])
|
| 282 |
+
all_coords = reduce_3d(all_vecs)
|
| 283 |
+
|
| 284 |
+
main_coords = all_coords[:len(valid)]
|
| 285 |
+
nbr_coords = all_coords[len(valid):] if nbr_data else None
|
| 286 |
+
|
| 287 |
+
# ββ Assign a distinct color to each word from the palette ββ
|
| 288 |
+
word_colors = {w: PALETTE[i % len(PALETTE)] for i, w in enumerate(valid)}
|
| 289 |
+
|
| 290 |
# ββ Build figure ββ
|
| 291 |
fig = go.Figure()
|
| 292 |
+
# Collect 3D annotations (HTML overlays β immune to WebGL text sizing bug)
|
| 293 |
+
annotations = []
|
| 294 |
+
|
| 295 |
+
def add_label(x, y, z, text, size=16, color=DARK, opacity=1.0):
|
| 296 |
+
annotations.append(dict(
|
| 297 |
+
x=x, y=y, z=z, text=text, showarrow=False,
|
| 298 |
+
font=dict(size=size, color=color, family="Inter, sans-serif"),
|
| 299 |
+
opacity=opacity, yshift=18,
|
| 300 |
+
))
|
| 301 |
|
| 302 |
+
if selected:
|
| 303 |
+
si = valid.index(selected)
|
| 304 |
+
sel_color = word_colors[selected]
|
| 305 |
+
nbr_color = lighten(sel_color, 0.3)
|
| 306 |
+
|
| 307 |
+
# Vector lines β selected word's line is bold, others are faded
|
| 308 |
+
vec_colors = [sel_color if w == selected else lighten(word_colors[w], 0.5)
|
| 309 |
+
for w in valid]
|
| 310 |
+
vec_widths = [4 if w == selected else 1.5 for w in valid]
|
| 311 |
+
add_vectors_from_origin(fig, main_coords, valid,
|
| 312 |
+
colors=vec_colors, widths=vec_widths)
|
| 313 |
+
|
| 314 |
+
# Dimmed words (not selected) β markers only
|
| 315 |
+
dim_idx = [i for i in range(len(valid)) if i != si]
|
| 316 |
+
if dim_idx:
|
| 317 |
+
fig.add_trace(go.Scatter3d(
|
| 318 |
+
x=[main_coords[i, 0] for i in dim_idx],
|
| 319 |
+
y=[main_coords[i, 1] for i in dim_idx],
|
| 320 |
+
z=[main_coords[i, 2] for i in dim_idx],
|
| 321 |
+
mode="markers",
|
| 322 |
+
marker=dict(
|
| 323 |
+
size=8,
|
| 324 |
+
color=[lighten(word_colors[valid[i]], 0.4) for i in dim_idx],
|
| 325 |
+
opacity=0.5,
|
| 326 |
+
line=dict(width=1, color="white"),
|
| 327 |
+
),
|
| 328 |
+
name="Words",
|
| 329 |
+
hoverinfo="text",
|
| 330 |
+
hovertext=[valid[i] for i in dim_idx],
|
| 331 |
+
))
|
| 332 |
+
for i in dim_idx:
|
| 333 |
+
add_label(main_coords[i, 0], main_coords[i, 1], main_coords[i, 2],
|
| 334 |
+
valid[i], size=15,
|
| 335 |
+
color=lighten(word_colors[valid[i]], 0.4), opacity=0.7)
|
| 336 |
|
| 337 |
+
# Selected word β bright, full color, markers only
|
|
|
|
| 338 |
fig.add_trace(go.Scatter3d(
|
| 339 |
+
x=[main_coords[si, 0]],
|
| 340 |
+
y=[main_coords[si, 1]],
|
| 341 |
+
z=[main_coords[si, 2]],
|
| 342 |
+
mode="markers",
|
|
|
|
|
|
|
|
|
|
| 343 |
marker=dict(
|
| 344 |
+
size=14, color=sel_color, opacity=1.0,
|
| 345 |
+
line=dict(width=2, color="white"),
|
| 346 |
),
|
| 347 |
+
name=f"Selected: {selected}",
|
| 348 |
hoverinfo="text",
|
| 349 |
+
hovertext=[f"{selected} (selected)"],
|
| 350 |
))
|
| 351 |
+
add_label(main_coords[si, 0], main_coords[si, 1], main_coords[si, 2],
|
| 352 |
+
f"<b>{selected}</b>", size=18, color=DARK)
|
| 353 |
|
| 354 |
+
# Neighbors β lighter shade of the selected word's color
|
| 355 |
+
if nbr_data and nbr_coords is not None:
|
|
|
|
| 356 |
fig.add_trace(go.Scatter3d(
|
| 357 |
+
x=nbr_coords[:, 0].tolist(),
|
| 358 |
+
y=nbr_coords[:, 1].tolist(),
|
| 359 |
+
z=nbr_coords[:, 2].tolist(),
|
| 360 |
+
mode="markers",
|
| 361 |
+
marker=dict(
|
| 362 |
+
size=9, color=nbr_color, opacity=0.9,
|
| 363 |
+
line=dict(width=1, color=sel_color),
|
| 364 |
+
),
|
| 365 |
+
name=f'Near "{selected}"',
|
| 366 |
+
hoverinfo="text",
|
| 367 |
+
hovertext=[f"{w} ({s:.3f})" for w, s in nbr_data],
|
| 368 |
))
|
| 369 |
+
for i, (w, s) in enumerate(nbr_data):
|
| 370 |
+
add_label(nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2],
|
| 371 |
+
w, size=16, color=DARK)
|
| 372 |
+
|
| 373 |
+
# Dotted lines from selected word to its neighbors
|
| 374 |
+
for i in range(len(nbr_data)):
|
| 375 |
+
fig.add_trace(go.Scatter3d(
|
| 376 |
+
x=[main_coords[si, 0], nbr_coords[i, 0]],
|
| 377 |
+
y=[main_coords[si, 1], nbr_coords[i, 1]],
|
| 378 |
+
z=[main_coords[si, 2], nbr_coords[i, 2]],
|
| 379 |
+
mode="lines",
|
| 380 |
+
line=dict(color=nbr_color, width=2, dash="dot"),
|
| 381 |
+
showlegend=False, hoverinfo="none",
|
| 382 |
+
))
|
| 383 |
+
else:
|
| 384 |
+
# No selection β each word gets its own color, markers only
|
| 385 |
+
colors = [word_colors[w] for w in valid]
|
| 386 |
+
add_vectors_from_origin(fig, main_coords, valid, colors=colors)
|
| 387 |
|
| 388 |
+
fig.add_trace(go.Scatter3d(
|
| 389 |
+
x=main_coords[:, 0].tolist(),
|
| 390 |
+
y=main_coords[:, 1].tolist(),
|
| 391 |
+
z=main_coords[:, 2].tolist(),
|
| 392 |
+
mode="markers",
|
| 393 |
+
marker=dict(
|
| 394 |
+
size=10, color=colors, opacity=0.9,
|
| 395 |
+
line=dict(width=1, color="white"),
|
| 396 |
+
),
|
| 397 |
+
name="Words",
|
| 398 |
+
hoverinfo="text",
|
| 399 |
+
hovertext=valid,
|
| 400 |
+
))
|
| 401 |
+
for i, w in enumerate(valid):
|
| 402 |
+
add_label(main_coords[i, 0], main_coords[i, 1], main_coords[i, 2],
|
| 403 |
+
w, size=16, color=DARK)
|
| 404 |
+
|
| 405 |
+
fig.update_layout(**layout_3d(), scene_annotations=annotations)
|
| 406 |
|
| 407 |
# Status
|
| 408 |
status = f"**{len(valid)} words** in 3D"
|
|
|
|
| 458 |
for w in neg:
|
| 459 |
rv -= model[w]
|
| 460 |
|
| 461 |
+
# Collect all words for MDS
|
| 462 |
result_words = [w for w, _ in result_entries]
|
| 463 |
all_words = operands + result_words
|
| 464 |
all_vecs = np.array([model[w] for w in all_words])
|
| 465 |
|
| 466 |
+
# Include result vector in MDS fit for best view
|
| 467 |
combined = np.vstack([all_vecs, rv.reshape(1, -1)])
|
| 468 |
+
coords_all = reduce_3d(combined)
|
| 469 |
|
| 470 |
rv_coord = coords_all[-1]
|
| 471 |
coords = coords_all[:-1]
|
|
|
|
| 509 |
mode="markers+text",
|
| 510 |
text=[operands[i] for i in pi],
|
| 511 |
textposition="top center",
|
| 512 |
+
textfont=dict(size=18, color=DARK),
|
| 513 |
marker=dict(size=10, color=PURPLE, opacity=0.9),
|
| 514 |
name="Positive (+)",
|
| 515 |
hoverinfo="text",
|
|
|
|
| 526 |
mode="markers+text",
|
| 527 |
text=[operands[i] for i in ni],
|
| 528 |
textposition="top center",
|
| 529 |
+
textfont=dict(size=18, color=DARK),
|
| 530 |
marker=dict(size=10, color=PINK, opacity=0.9),
|
| 531 |
name="Negative (β)",
|
| 532 |
hoverinfo="text",
|
|
|
|
| 540 |
mode="markers+text",
|
| 541 |
text=[f"β {top_result[0]}"],
|
| 542 |
textposition="top center",
|
| 543 |
+
textfont=dict(size=20, color=PURPLE),
|
| 544 |
marker=dict(
|
| 545 |
size=14, color=GOLD, opacity=1.0, symbol="diamond",
|
| 546 |
line=dict(width=2, color=PURPLE),
|
|
|
|
| 560 |
mode="markers+text",
|
| 561 |
text=[f"{w} ({s:.2f})" for w, s in other_results],
|
| 562 |
textposition="top center",
|
| 563 |
+
textfont=dict(size=14, color=GRAY),
|
| 564 |
+
marker=dict(size=6, color=PURPLE_LIGHT, opacity=0.7),
|
| 565 |
name="Other matches",
|
| 566 |
hoverinfo="text",
|
| 567 |
hovertext=[f"{w} ({s:.3f})" for w, s in other_results],
|
|
|
|
| 614 |
CSS = """
|
| 615 |
.gradio-container { max-width: 1200px !important; }
|
| 616 |
h1 { color: #63348d !important; }
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
/* Remove button-like styling from labels */
|
| 620 |
+
.gradio-container label span {
|
| 621 |
+
border: none !important;
|
| 622 |
+
background: transparent !important;
|
| 623 |
+
box-shadow: none !important;
|
| 624 |
+
padding-left: 0 !important;
|
| 625 |
+
font-weight: 600 !important;
|
| 626 |
+
color: #63348d !important;
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
/* Block backgrounds β light purple instead of near-white */
|
| 630 |
+
.gradio-container .block {
|
| 631 |
+
background: #f3f0f7 !important;
|
| 632 |
+
border: 1px solid #ded9f4 !important;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
/* Input fields should be white for contrast against purple blocks */
|
| 636 |
+
.gradio-container textarea,
|
| 637 |
+
.gradio-container input[type="text"] {
|
| 638 |
+
background: #ffffff !important;
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
/* Tab styling */
|
| 642 |
+
.gradio-container button.tab-nav-button.selected {
|
| 643 |
+
color: #63348d !important;
|
| 644 |
+
border-bottom-color: #63348d !important;
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
/* Radio items β less boxy */
|
| 648 |
+
.gradio-container .wrap .radio-group label {
|
| 649 |
+
border: none !important;
|
| 650 |
+
background: transparent !important;
|
| 651 |
+
}
|
| 652 |
"""
|
| 653 |
|
| 654 |
+
FORCE_LIGHT = '<script>if(!location.search.includes("__theme=light")){const u=new URL(location);u.searchParams.set("__theme","light");location.replace(u)}</script>'
|
| 655 |
+
|
| 656 |
with gr.Blocks(
|
| 657 |
title="Embedding Explorer",
|
| 658 |
+
head=FORCE_LIGHT,
|
| 659 |
theme=gr.themes.Soft(
|
| 660 |
primary_hue="purple",
|
| 661 |
font=gr.themes.GoogleFont("Inter"),
|
| 662 |
+
).set(
|
| 663 |
+
button_primary_background_fill="#63348d",
|
| 664 |
+
button_primary_background_fill_hover="#4a2769",
|
| 665 |
+
button_primary_text_color="#ffffff",
|
| 666 |
+
block_background_fill="#f3f0f7",
|
| 667 |
+
block_border_color="#ded9f4",
|
| 668 |
+
body_background_fill="#ffffff",
|
| 669 |
),
|
| 670 |
css=CSS,
|
| 671 |
) as demo:
|
|
|
|
| 688 |
"Similar words cluster together. "
|
| 689 |
"Click a word below the plot to see its nearest neighbors.*"
|
| 690 |
)
|
| 691 |
+
exp_plot = gr.Plot(label="Embedding Space")
|
| 692 |
with gr.Row():
|
| 693 |
+
with gr.Column(scale=2):
|
| 694 |
exp_in = gr.Textbox(
|
| 695 |
label="Words (space-separated, min 3)",
|
| 696 |
placeholder="dog cat fish car truck",
|
| 697 |
lines=1,
|
| 698 |
)
|
| 699 |
+
with gr.Column(scale=1):
|
| 700 |
exp_btn = gr.Button("Show in 3D", variant="primary")
|
| 701 |
+
exp_status = gr.Markdown("")
|
| 702 |
+
exp_radio = gr.Radio(
|
| 703 |
+
label="Click to see nearest neighbors",
|
| 704 |
+
choices=[], value=None,
|
| 705 |
+
visible=False, interactive=True,
|
| 706 |
+
)
|
| 707 |
+
gr.Examples(
|
| 708 |
+
examples=[[e] for e in EXPLORE_EXAMPLES],
|
| 709 |
+
inputs=[exp_in],
|
| 710 |
+
label="Try these",
|
| 711 |
+
)
|
|
|
|
|
|
|
| 712 |
|
| 713 |
# Events
|
| 714 |
exp_btn.click(
|
|
|
|
| 733 |
"*Word vectors can do math! "
|
| 734 |
"The results reveal hidden relationships between words.*"
|
| 735 |
)
|
| 736 |
+
math_plot = gr.Plot(label="Vector Arithmetic")
|
| 737 |
with gr.Row():
|
| 738 |
+
with gr.Column(scale=2):
|
| 739 |
math_in = gr.Textbox(
|
| 740 |
label="Expression",
|
| 741 |
placeholder="king - man + woman",
|
| 742 |
lines=1,
|
| 743 |
)
|
| 744 |
+
with gr.Column(scale=1):
|
| 745 |
math_btn = gr.Button("Compute", variant="primary")
|
| 746 |
+
math_status = gr.Markdown("")
|
| 747 |
+
gr.Examples(
|
| 748 |
+
examples=[[e] for e in ARITHMETIC_EXAMPLES],
|
| 749 |
+
inputs=[math_in],
|
| 750 |
+
label="Try these",
|
| 751 |
+
)
|
|
|
|
|
|
|
| 752 |
|
| 753 |
# Events
|
| 754 |
math_btn.click(
|