"""
Embedding Explorer — Interactive word vector visualization
Responsible AI: Technology, Power, and Justice
Huston-Tillotson University
Enter words or vector expressions (comma-separated). Each item becomes
an arrow in 3D. Click an item to see its nearest neighbors.
Configuration (HuggingFace Space environment variables):
EXAMPLES — JSON list of example inputs (words or expressions)
N_NEIGHBORS — number of neighbors to show on click (default 8)
"""
import os
import json
import re
import warnings
import urllib.parse
import subprocess
import pandas # noqa: F401 — import before plotly to avoid circular import
import numpy as np
import plotly.graph_objects as go
import gradio as gr
# Silence sklearn's FutureWarnings (MDS API churn) so Space logs stay readable.
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# ── Configuration (all changeable via HF Space env vars) ─────
# EXAMPLES: JSON list of example inputs — plain word lists and/or
# comma-separated vector-arithmetic expressions.
EXAMPLES = json.loads(os.environ.get("EXAMPLES", json.dumps([
    "dog cat fish car truck",
    "paris france berlin germany tokyo japan",
    "man woman king queen prince princess",
    "man - woman, uncle - aunt, man woman uncle aunt",
    "aunt - woman + man, man woman uncle aunt",
    "nephew - man + woman, man woman nephew niece",
    "king - man + woman, man woman king queen",
    "paris - france + italy, paris france italy rome",
    "sushi - japan + germany, sushi japan germany bratwurst",
    "hitler - germany + italy, germany italy hitler mussolini",
])))
# Number of nearest neighbors shown when an item is clicked.
N_NEIGHBORS = int(os.environ.get("N_NEIGHBORS", "4"))

# ── Share URL infrastructure ─────────────────────────────────
# Optional Rebrandly key for link shortening; empty string disables it.
REBRANDLY_API_KEY = os.environ.get("REBRANDLY_API_KEY", "")
# SPACE_ID ("owner/name") is set by HuggingFace Spaces; derive the public
# Space URL from it, otherwise fall back to the local dev server.
_SPACE_ID = os.environ.get("SPACE_ID", "")
if _SPACE_ID:
    _owner, _name = _SPACE_ID.split("/")
    _BASE_URL = f"https://{_owner}-{_name}.hf.space/"
else:
    _BASE_URL = "http://localhost:7860/"
# ── Course design system colors ──────────────────────────────
PURPLE = "#63348d"        # primary brand purple
PURPLE_LIGHT = "#ded9f4"  # light purple (borders, hovers)
PURPLE_DARK = "#301848"   # dark purple (text accents)
GOLD = "#f0c040"          # gold — highlights expression-result vectors
PINK = "#de95a0"
DARK = "#1a1a2e"          # default label/text color
GRAY = "#888888"
BG = "#fafafa"
# Categorical palette — darkest shade from each design-system color family
# 10 shades per family from color_palette.md (darkest → lightest)
# NOTE(review): PURPLE above is #63348d but _SHADES["purple"][0] is #64348d,
# so lighten(PURPLE) takes the arithmetic fallback rather than the palette
# lookup — confirm whether that one-digit difference is intentional.
_SHADES = {
    "purple": ["#64348d", "#6f3ba4", "#7942bb", "#8455c5", "#906acd",
               "#9d80d6", "#ab95de", "#baabe5", "#cbc1ec", "#ddd7f4"],
    "blue": ["#344b8d", "#3b58a4", "#4266bb", "#5579c5", "#6a8dcd",
             "#809fd6", "#95b2de", "#abc3e5", "#c1d4ec", "#d7e5f4"],
    "green": ["#348d64", "#3ba470", "#42bb7c", "#55c588", "#6acd95",
              "#80d6a2", "#95deb0", "#abe5bf", "#c1eccf", "#d7f4e0"],
    "red": ["#8d3437", "#a43b41", "#bb424a", "#c5555f", "#cd6a75",
            "#d6808a", "#de95a0", "#e5abb4", "#ecc1c9", "#f4d7dd"],
    "yellow": ["#8d7734", "#a48b3b", "#bb9f42", "#c5ac55", "#cdb86a",
               "#d6c480", "#decf95", "#e5daab", "#ece5c1", "#f4efd7"],
    "orange": ["#8d5534", "#a4633b", "#bb7242", "#c58355", "#cd946a",
               "#d6a580", "#deb695", "#e5c6ab", "#ecd6c1", "#f4e5d7"],
}
# Darkest shade from each family — up to 12 items cycle through these 6
PALETTE = [
    _SHADES["purple"][0],
    _SHADES["blue"][0],
    _SHADES["red"][0],
    _SHADES["yellow"][0],
    _SHADES["green"][0],
    _SHADES["orange"][0],
]
# Map each darkest color to its shade family for lookup
_COLOR_FAMILY = {shades[0]: shades for shades in _SHADES.values()}
def lighten(hex_color, amount=0.3):
    """Lighten a color using the design-system palette shades.

    Maps amount (0.0–1.0) to a palette shade index when the color is a
    known family anchor; otherwise blends arithmetically toward white
    (e.g., for GOLD, which has no shade family).
    """
    key = hex_color.lower()
    family = _COLOR_FAMILY.get(key)
    if family is not None:
        # Clamp the mapped index to the lightest available shade.
        return family[min(int(amount * len(family)), len(family) - 1)]
    # Fallback: per-channel linear blend toward white.
    stripped = hex_color.lstrip("#")
    channels = [int(stripped[i:i + 2], 16) for i in (0, 2, 4)]
    blended = [int(c + (255 - c) * amount) for c in channels]
    return "#{:02x}{:02x}{:02x}".format(*blended)
# ── Load GloVe embeddings on startup ─────────────────────────
import time
import gensim.downloader as api
from gensim.models import KeyedVectors
# Native binary cache — loads ~10x faster than gensim's text format.
# Stored alongside this file so the Space's persistent storage keeps it.
_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".cache")
_CACHE_PATH = os.path.join(_CACHE_DIR, "glove-wiki-gigaword-300.kv")
def load_model(name="glove-wiki-gigaword-300", retries=5):
    """Load GloVe vectors. Uses native binary cache for fast startup after first run.

    Args:
        name: gensim-downloader dataset name.
        retries: number of download attempts before giving up.

    Returns:
        gensim KeyedVectors (memory-mapped when loaded from cache).

    Raises:
        Exception: re-raises the last download error once retries are exhausted.
    """
    # Fast path: load from native binary cache (memory-mapped)
    if os.path.exists(_CACHE_PATH):
        print("=" * 60)
        print("Loading GloVe vectors from cache (memory-mapped)...")
        print("=" * 60)
        t0 = time.time()
        m = KeyedVectors.load(_CACHE_PATH, mmap="r")
        print(f"Loaded in {time.time() - t0:.1f}s")
        return m
    # Slow path: download via gensim, then save native cache
    for attempt in range(1, retries + 1):
        try:
            print("=" * 60)
            print(f"Downloading GloVe word vectors ({name})...")
            if attempt == 1:
                print("First run only — ~376 MB download. Will cache for fast startup.")
            else:
                print(f"Retry {attempt}/{retries}...")
            print("=" * 60)
            m = api.load(name)
            # Save native binary cache for next time
            os.makedirs(_CACHE_DIR, exist_ok=True)
            m.save(_CACHE_PATH)
            print(f"Cached to {_CACHE_DIR} for fast startup next time.")
            return m
        except Exception as e:
            # Broad catch is deliberate: network/HTTP errors from gensim's
            # downloader vary; retry with exponential backoff (2s, 4s, 8s, ...).
            print(f"Attempt {attempt} failed: {e}")
            if attempt < retries:
                wait = 2 ** attempt
                print(f"Retrying in {wait}s...")
                time.sleep(wait)
            else:
                raise
# Load at import time so the Space is ready before the UI comes up.
model = load_model()
VOCAB = set(model.key_to_index.keys())  # fast membership checks during parsing
DIMS = model.vector_size
print(f"Ready: {len(VOCAB):,} words, {DIMS} dimensions each")
# ── Helpers ──────────────────────────────────────────────────
def parse_expression(expr):
    """Parse 'king - man + woman' → (positives, negatives, ordered).

    ordered is [(word, sign, coeff), ...] for display formatting.
    Supports fractional coefficients: '0.5 king - 0.3 man + 1.5 woman'.
    Tokens that are not in the vocabulary are silently skipped.
    """
    pos, neg, ordered = [], [], []
    sign, coeff = "+", 1.0
    # Tokenize into numbers, lowercase words (apostrophes allowed), and ops.
    for token in re.findall(r"\d*\.?\d+|[a-z']+|[+-]", expr.lower()):
        if token in "+-":
            sign, coeff = token, 1.0  # operator resets pending coefficient
        elif re.fullmatch(r"\d*\.?\d+", token):
            coeff = float(token)  # coefficient applies to the next word
        elif token in VOCAB:
            bucket = pos if sign == "+" else neg
            bucket.append((token, coeff))
            ordered.append((token, sign, coeff))
            coeff = 1.0
    return pos, neg, ordered
def _coeff_str(c):
"""Format coefficient for display. Returns '' for 1.0, '0.5\u00b7' otherwise."""
if c == 1.0:
return ""
if c == int(c):
return f"{int(c)}\u00b7"
return f"{c:g}\u00b7"
def parse_items(text):
    """Parse comma-separated input into items (words or vector expressions).

    Returns (items, bad_words) where each item is:
        (label, vector, is_expr, operand_words, ordered_ops_or_None)

    Duplicate labels are dropped (first occurrence wins); words missing
    from the GloVe vocabulary are collected into bad_words for reporting.
    """
    if not text or not text.strip():
        return [], []
    items = []
    bad_words = []
    seen_labels = set()
    # Split by comma → parts
    parts = [p.strip() for p in text.split(",")]
    for part in parts:
        if not part:
            continue
        # Detect expression: contains + or - between word characters or digits
        if re.search(r"(?:[a-z']|\d)\s*[+\-]\s*(?:[a-z']|\d)", part.lower()):
            # It's an arithmetic expression
            pos, neg, ordered = parse_expression(part)
            if len(pos) + len(neg) < 2:
                # Too few valid operands — report any out-of-vocab words
                all_tokens = re.findall(r"[a-z']+", part.lower())
                bad_words.extend(t for t in all_tokens if t not in VOCAB)
                continue
            # Compute result vector: weighted sum of positives minus negatives
            vec = np.zeros(DIMS)
            for w, c in pos:
                vec += c * model[w]
            for w, c in neg:
                vec -= c * model[w]
            # Build display label, e.g. "king − man + woman" (U+2212 minus)
            label_parts = []
            for w, s, c in ordered:
                cstr = _coeff_str(c)
                if not label_parts:
                    label_parts.append(f"{cstr}{w}")
                elif s == "+":
                    label_parts.append(f"+ {cstr}{w}")
                else:
                    label_parts.append(f"− {cstr}{w}")
            label = " ".join(label_parts)
            if label not in seen_labels:
                seen_labels.add(label)
                operand_words = set(w for w, c in pos + neg)
                items.append((label, vec, True, operand_words, list(ordered)))
        else:
            # Plain words — each word is a separate item
            words = re.split(r"\s+", part.lower().strip())
            for w in words:
                w = w.strip()
                if not w:
                    continue
                if w not in VOCAB:
                    bad_words.append(w)
                    continue
                if w not in seen_labels:
                    seen_labels.add(w)
                    items.append((w, model[w], False, {w}, None))
    return items, bad_words
def reduce_3d(vectors):
    """MDS (cosine distance) → 3D. Normalizes to [-1, 1] for consistent label sizing."""
    # Imported lazily so startup doesn't pay for sklearn before first use.
    from sklearn.manifold import MDS
    from sklearn.metrics.pairwise import cosine_distances

    count = len(vectors)
    if count < 2:
        return np.zeros((count, 3))
    n_comp = min(3, count)
    embedder = MDS(n_components=n_comp, dissimilarity="precomputed",
                   random_state=42, normalized_stress="auto", max_iter=300)
    coords = embedder.fit_transform(cosine_distances(vectors))
    if n_comp < 3:
        # Pad missing dimensions with zeros so callers always get 3 columns.
        coords = np.hstack([coords, np.zeros((count, 3 - n_comp))])
    # Normalize to [-1, 1] so axis ranges and label sizes are consistent
    peak = np.abs(coords).max()
    return coords / peak if peak > 1e-8 else coords
def _axis(title=""):
"""3D axis — minimal, no built-in grid (we draw our own floor grid)."""
return dict(
showgrid=False,
zeroline=False,
showticklabels=False,
title=title,
showspikes=False,
showbackground=False,
)
def layout_3d(axis_range=1.3, camera=None):
    """Shared Plotly 3D layout. Uses uirevision to preserve camera across updates."""
    span = [-axis_range, axis_range]
    axes = []
    for _ in range(3):
        ax = _axis()
        ax["range"] = span  # identical fixed range on all three axes
        axes.append(ax)
    # Default oblique viewpoint when the caller supplies no camera.
    fallback_camera = dict(eye=dict(x=1.0, y=1.0, z=0.8))
    return dict(
        scene=dict(
            xaxis=axes[0],
            yaxis=axes[1],
            zaxis=axes[2],
            bgcolor="white",
            camera=camera or fallback_camera,
            aspectmode="cube",
        ),
        paper_bgcolor="white",
        margin=dict(l=0, r=0, t=10, b=10),
        showlegend=True,
        legend=dict(
            yanchor="top", y=0.99, xanchor="right", x=0.99,
            bgcolor="rgba(255,255,255,0.85)",
            font=dict(family="Inter, sans-serif", size=12),
        ),
        font=dict(family="Inter, sans-serif"),
        autosize=True,
        uirevision="keep",
    )
# Default arrow geometry (shared by add_arrow and the explore() renderer)
ARROW_WIDTH = 10         # line width in pixels
ARROW_HEAD_LENGTH = 0.08  # arrowhead length in coordinate units
ARROW_HEAD_WIDTH = 0.03   # arrowhead half-width in coordinate units
def add_arrow(fig, px, py, pz, color, width=ARROW_WIDTH,
              head_length=ARROW_HEAD_LENGTH, head_width=ARROW_HEAD_WIDTH,
              sx=0, sy=0, sz=0, dash=None):
    """Draw a vector arrow from (sx,sy,sz) to (px,py,pz) with a flat arrowhead.

    Adds two traces to fig: a Scatter3d line (shortened so it doesn't poke
    through the head) and a Mesh3d pyramid for the head. Zero-length vectors
    are skipped entirely.
    """
    start = np.array([sx, sy, sz])
    tip = np.array([px, py, pz])
    vec = tip - start
    length = np.linalg.norm(vec)
    if length < 1e-8:
        return
    d = vec / length  # unit direction
    # Shorten line so it doesn't overlap arrowhead
    shorten = min(head_length, length * 0.3)
    end = tip - d * shorten
    # Vector line
    fig.add_trace(go.Scatter3d(
        x=[start[0], end[0]], y=[start[1], end[1]], z=[start[2], end[2]],
        mode="lines", line=dict(color=color, width=width, dash=dash),
        showlegend=False, hoverinfo="none",
    ))
    # Flat arrowhead using Mesh3d (diamond-cross pyramid, visible from all angles)
    # Pick a reference "up" that is not parallel to d, then build two
    # perpendicular base directions via cross products.
    up = np.array([0, 0, 1]) if abs(d[2]) < 0.9 else np.array([0, 1, 0])
    p1 = np.cross(d, up)
    p1 = p1 / np.linalg.norm(p1)
    p2 = np.cross(d, p1)
    base = tip - d * head_length
    w1 = base + p1 * head_width
    w2 = base - p1 * head_width
    w3 = base + p2 * head_width
    w4 = base - p2 * head_width
    # 5 vertices: tip + 4 base points
    vx = [tip[0], w1[0], w2[0], w3[0], w4[0]]
    vy = [tip[1], w1[1], w2[1], w3[1], w4[1]]
    vz = [tip[2], w1[2], w2[2], w3[2], w4[2]]
    # 4 triangular faces forming the pyramid (i/j/k index the vertex lists)
    fig.add_trace(go.Mesh3d(
        x=vx, y=vy, z=vz,
        i=[0, 0, 0, 0], j=[1, 3, 2, 4], k=[3, 2, 4, 1],
        color=color, opacity=1.0,
        flatshading=True,
        # Full ambient light so the head renders as a flat solid color.
        lighting=dict(ambient=1, diffuse=0, specular=0, fresnel=0),
        showlegend=False, hoverinfo="none",
    ))
def add_floor_grid(fig, range_val=1.3, step=0.25):
    """Draw a grid on the z=0 plane and axis lines through origin (3B1B style)."""
    grid_color = "rgba(99,52,141,0.18)"
    axis_color = "rgba(99,52,141,0.60)"
    ticks = np.arange(-range_val, range_val + step / 2, step)

    def _grid_trace(segments):
        """Add one batched line trace; None entries break the polyline."""
        xs, ys, zs = [], [], []
        for (x0, x1), (y0, y1) in segments:
            xs.extend([x0, x1, None])
            ys.extend([y0, y1, None])
            zs.extend([0, 0, None])
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode="lines", line=dict(color=grid_color, width=1),
            showlegend=False, hoverinfo="none",
        ))

    # Lines parallel to the Y axis (varying x), then parallel to X (varying y).
    _grid_trace([((v, v), (-range_val, range_val)) for v in ticks])
    _grid_trace([((-range_val, range_val), (v, v)) for v in ticks])
    # Three axis lines through the origin.
    for axis_x, axis_y, axis_z in (
        ([-range_val, range_val], [0, 0], [0, 0]),  # X
        ([0, 0], [-range_val, range_val], [0, 0]),  # Y
        ([0, 0], [0, 0], [-range_val, range_val]),  # Z
    ):
        fig.add_trace(go.Scatter3d(
            x=list(axis_x), y=list(axis_y), z=list(axis_z),
            mode="lines", line=dict(color=axis_color, width=2),
            showlegend=False, hoverinfo="none",
        ))
def blank(msg):
    """Empty placeholder figure with a centered message."""
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=msg, xref="paper", yref="paper", x=0.5, y=0.5,
        showarrow=False,
        font=dict(size=16, color=GRAY, family="Inter, sans-serif"),
    )
    # Blank white canvas sized like the real plot, with axes hidden.
    placeholder.update_layout(
        xaxis_visible=False, yaxis_visible=False,
        height=560, paper_bgcolor="white", plot_bgcolor="white",
        margin=dict(l=0, r=0, t=0, b=0),
    )
    return placeholder
# ── Share URL helpers ────────────────────────────────────────
def _shorten_url(long_url):
    """Shorten a URL via Rebrandly API (using curl). Falls back to long URL."""
    # No key configured, or a localhost link Rebrandly would reject: pass through.
    if not REBRANDLY_API_KEY or "localhost" in long_url:
        return long_url
    body = json.dumps({
        "destination": long_url,
        "domain": {"fullName": "go.ropavieja.org"},
    })
    command = [
        "curl", "-s", "-X", "POST",
        "https://api.rebrandly.com/v1/links",
        "-H", "Content-Type: application/json",
        "-H", f"apikey: {REBRANDLY_API_KEY}",
        "-d", body,
    ]
    try:
        proc = subprocess.run(command, capture_output=True, text=True, timeout=10)
        if proc.returncode != 0 or not proc.stdout.strip():
            return long_url
        return f"https://{json.loads(proc.stdout)['shortUrl']}"
    except (subprocess.TimeoutExpired, KeyError, json.JSONDecodeError, OSError) as exc:
        # Best-effort: log and fall back to the unshortened URL.
        print(f"[share] Rebrandly error: {exc}")
        return long_url
def _parse_camera(cam_str):
"""Parse compact camera string (ex,ey,ez[,cx,cy,cz,ux,uy,uz]) to Plotly camera dict."""
if not cam_str:
return None
try:
vals = [float(v) for v in cam_str.split(",")]
if len(vals) >= 3:
camera = dict(eye=dict(x=vals[0], y=vals[1], z=vals[2]))
if len(vals) >= 6:
camera["center"] = dict(x=vals[3], y=vals[4], z=vals[5])
if len(vals) >= 9:
camera["up"] = dict(x=vals[6], y=vals[7], z=vals[8])
return camera
except (ValueError, IndexError):
pass
return None
def _encode_camera(camera_json):
"""Encode Plotly camera JSON to compact string for URL params."""
if not camera_json:
return ""
try:
cam = json.loads(camera_json)
eye = cam.get("eye", {})
center = cam.get("center", {})
up = cam.get("up", {})
vals = [
eye.get("x", 1.5), eye.get("y", 1.5), eye.get("z", 1.2),
center.get("x", 0), center.get("y", 0), center.get("z", 0),
up.get("x", 0), up.get("y", 0), up.get("z", 1),
]
return ",".join(f"{v:.2f}" for v in vals)
except (json.JSONDecodeError, TypeError):
return ""
# ── Main visualization ───────────────────────────────────────
def explore(input_text, selected, hidden=None, camera=None, n_neighbors=None):
    """Unified 3D visualization of words and vector expressions.

    Args:
        input_text: Comma-separated words and/or expressions.
        selected: Currently selected item for neighbor display (or None).
        hidden: Set of labels to hide from rendering (MDS still uses all items).
        camera: Plotly camera dict to set initial view.
        n_neighbors: Number of nearest neighbors to show (default N_NEIGHBORS).

    Returns:
        (fig, status_md, radio_update, all_labels)
    """
    if not input_text or not input_text.strip():
        return (
            blank("Enter words or expressions above to visualize in 3D"),
            "",
            gr.update(choices=[], value=None, visible=False),
            [],
        )
    items, bad = parse_items(input_text)
    if not items:
        msg = "No valid items found."
        if bad:
            # FIX: newline must be escaped — a literal line break inside a
            # single-quoted f-string is a SyntaxError.
            msg += f"\nNot in vocabulary: {', '.join(bad)}"
        return blank(msg), "", gr.update(choices=[], value=None, visible=False), []
    items = items[:12]  # cap at palette capacity (6 colors × 2 cycles)
    labels = [item[0] for item in items]
    hidden = hidden or set()
    # Visible labels for radio choices (exclude hidden)
    visible_labels = [l for l in labels if l not in hidden]
    # No auto-select — user clicks radio to see neighbors
    if selected and selected != "(clear)" and selected in visible_labels:
        sel_idx = labels.index(selected)
    else:
        selected = None
        sel_idx = None
    # ── Collect all unique operand words for MDS ──
    all_words = []
    word_set = set()
    for _, _, _, ops, _ in items:
        for w in ops:
            if w not in word_set:
                word_set.add(w)
                all_words.append(w)
    # Find nearest high-D word for each expression (used in status + MDS padding)
    expr_nearest = {}  # label -> (word, similarity)
    for label, vec, is_expr, ops, ordered in items:
        if not is_expr:
            continue
        # Over-fetch so we can skip the expression's own operands.
        nearest = model.similar_by_vector(vec, topn=len(ops) + 5)
        for w, s in nearest:
            if w not in ops:
                expr_nearest[label] = (w, s)
                break
    # Pad with neighbor words if < 3 unique words (breaks MDS collinearity)
    helper_words = set()
    if len(all_words) < 3:
        # Try expression nearest-words first
        for label in expr_nearest:
            nw = expr_nearest[label][0]
            if nw not in word_set:
                word_set.add(nw)
                all_words.append(nw)
                helper_words.add(nw)
            if len(all_words) >= 3:
                break
        # Then try neighbors of any plain word
        if len(all_words) < 3:
            for w in list(all_words):
                if w in helper_words:
                    continue
                for nw, _ in model.most_similar(w, topn=5):
                    if nw not in word_set:
                        word_set.add(nw)
                        all_words.append(nw)
                        helper_words.add(nw)
                    if len(all_words) >= 3:
                        break
                if len(all_words) >= 3:
                    break
    # Gather neighbors if something is selected (and not hidden)
    nn = n_neighbors if n_neighbors is not None else N_NEIGHBORS
    nbr_data = []
    if selected is not None:
        sel_item = items[sel_idx]
        if sel_item[2]:  # expression
            raw = model.similar_by_vector(sel_item[1], topn=nn + 20)
        else:
            raw = model.most_similar(sel_item[0], topn=nn + 20)
        # Exclude words already on screen (operands and item labels).
        all_op_words = set()
        for _, _, _, ops, _ in items:
            all_op_words.update(ops)
        label_set = set(labels)
        nbr_data = [(w, s) for w, s in raw
                    if w not in all_op_words and w not in label_set][:nn]
    # ── MDS on all operand words + neighbors ──
    mds_words = all_words + [w for w, _ in nbr_data]
    if not mds_words:
        return blank("No valid words found."), "", gr.update(
            choices=[], value=None, visible=False), []
    mds_vecs = np.array([model[w] for w in mds_words])
    mds_coords = reduce_3d(mds_vecs)
    word_3d = {w: mds_coords[i] for i, w in enumerate(all_words)}
    nbr_coords = mds_coords[len(all_words):] if nbr_data else None
    # ── Compute expression results in 3D ──
    extra_points = []  # for dynamic axis range
    expr_info = {}  # label -> visualization data
    for label, vec, is_expr, ops, ordered in items:
        if not is_expr:
            continue
        pos_words = [w for w, s, c in ordered if s == "+"]
        neg_words = [w for w, s, c in ordered if s == "-"]
        pos_coeffs = [c for w, s, c in ordered if s == "+"]
        if len(neg_words) == 0 and len(pos_words) == 2:
            # Simple addition: chain tip-to-tail + gold result from origin
            a_3d = word_3d[pos_words[0]]
            b_3d = word_3d[pos_words[1]]
            result_3d = pos_coeffs[0] * a_3d + pos_coeffs[1] * b_3d
            extra_points.append(result_3d)
            expr_info[label] = ('add', pos_words[0], pos_words[1],
                                pos_coeffs[0], pos_coeffs[1], result_3d)
        else:
            # General: chain through operands + gold result from origin
            cursor = np.zeros(3)
            chain = []
            for w, s, coeff in ordered:
                prev = cursor.copy()
                c = word_3d[w]
                cursor = cursor + coeff * c if s == "+" else cursor - coeff * c
                chain.append((prev.copy(), cursor.copy(), w, s))
            extra_points.append(cursor.copy())
            expr_info[label] = ('chain', chain, cursor.copy())
    # ── Dynamic axis range (computed from ALL items, not just visible) ──
    all_rendered = [word_3d[w] for w in all_words] + extra_points
    if nbr_coords is not None:
        all_rendered.extend(nbr_coords)
    if all_rendered:
        max_abs = np.abs(np.array(all_rendered)).max()
        axis_range = max(1.3, max_abs * 1.15)
    else:
        axis_range = 1.3
    # ── Colors ──
    item_colors = [PALETTE[i % len(PALETTE)] for i in range(len(items))]
    # ── Build figure ──
    fig = go.Figure()
    add_floor_grid(fig, range_val=axis_range)
    annotations = []

    def add_label(x, y, z, text, size=16, color=DARK, opacity=1.0,
                  outward=0.07):
        """Add a 3D text label, shifted outward from origin to avoid overlapping vectors."""
        pt = np.array([x, y, z])
        norm = np.linalg.norm(pt)
        if norm > 1e-8 and outward > 0:
            pt = pt + (pt / norm) * outward
        annotations.append(dict(
            x=float(pt[0]), y=float(pt[1]), z=float(pt[2]),
            text=text, showarrow=False,
            font=dict(size=size, color=color, family="Inter, sans-serif"),
            opacity=opacity, yshift=12,
        ))

    for idx, (label, vec, is_expr, ops, ordered) in enumerate(items):
        # Skip hidden items
        if label in hidden:
            continue
        color = item_colors[idx]
        is_sel = (sel_idx is not None and idx == sel_idx)
        is_dim = (sel_idx is not None and idx != sel_idx)
        # Selection-aware styling: dim everything except the selected item
        arr_color = lighten(color, 0.5) if is_dim else color
        arr_width = 4 if is_dim else (12 if is_sel else ARROW_WIDTH)
        lbl_color = lighten(color, 0.5) if is_dim else color
        lbl_opacity = 0.7 if is_dim else 1.0
        lbl_size = 18 if is_sel else (15 if is_dim else 16)
        gold = lighten(GOLD, 0.3) if is_dim else GOLD
        gold_width = 6 if is_dim else 12
        if not is_expr:
            # ── Plain word: arrow from origin ──
            c = word_3d[label]
            add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width)
            txt = f"{label}" if is_sel else label
            add_label(c[0], c[1], c[2], txt, size=lbl_size,
                      color=lbl_color, opacity=lbl_opacity)
        else:
            info = expr_info[label]
            # ── Operand arrows from origin (full length — shows where the word IS) ──
            for w, s, coeff in ordered:
                c = word_3d[w]
                add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width)
                txt = f"{w}" if is_sel else w
                add_label(c[0], c[1], c[2], txt, size=lbl_size,
                          color=lbl_color, opacity=lbl_opacity)
            # ── Construction arrows with expression labels beside midpoint ──
            gold_lbl = f"{label}"
            gold_lbl_size = 14 if is_dim else 15
            gold_lbl_color = "#b08820" if is_dim else "#8a6a10"

            def _label_beside(start, end):
                """Place label at midpoint of start→end, offset perpendicular
                to the arrow (on the side farther from origin)."""
                mid = (start + end) / 2
                d = end - start
                length = np.linalg.norm(d)
                if length < 1e-8:
                    return mid
                d = d / length
                up = np.array([0., 0., 1.]) if abs(d[2]) < 0.9 \
                    else np.array([0., 1., 0.])
                perp = np.cross(d, up)
                pn = np.linalg.norm(perp)
                if pn < 1e-8:
                    return mid
                perp = perp / pn * 0.12
                # Pick side farther from origin (pushes label outward)
                if np.dot(mid + perp, mid + perp) >= np.dot(mid - perp, mid - perp):
                    return mid + perp
                return mid - perp

            if info[0] == 'add':
                # Second operand drawn from tip of first (chain)
                # info = ('add', word_a, word_b, coeff_a, coeff_b, result_3d)
                a = word_3d[info[1]]
                coeff_a = info[3]
                result_3d = info[5]
                chain_start = coeff_a * a
                chain_color = lighten(color, 0.3) if is_dim else lighten(color, 0.2)
                add_arrow(fig, result_3d[0], result_3d[1], result_3d[2],
                          chain_color, width=arr_width,
                          sx=chain_start[0], sy=chain_start[1], sz=chain_start[2],
                          dash="dot")
                # Gold result from origin
                add_arrow(fig, result_3d[0], result_3d[1], result_3d[2],
                          gold, width=gold_width)
                origin = np.zeros(3)
                lpt = _label_beside(origin, result_3d)
                add_label(lpt[0], lpt[1], lpt[2], gold_lbl,
                          size=gold_lbl_size, color=gold_lbl_color,
                          opacity=lbl_opacity, outward=0)
            else:  # chain
                chain_steps = info[1]
                result_3d = info[2]
                chain_color = lighten(color, 0.3) if is_dim else lighten(color, 0.2)
                for i, (start, end, w, s) in enumerate(chain_steps):
                    if i == 0:
                        continue  # first step overlaps operand arrow
                    add_arrow(fig, end[0], end[1], end[2],
                              chain_color, width=arr_width,
                              sx=start[0], sy=start[1], sz=start[2], dash="dot")
                # Gold result from origin
                add_arrow(fig, result_3d[0], result_3d[1], result_3d[2],
                          gold, width=gold_width)
                origin = np.zeros(3)
                lpt = _label_beside(origin, result_3d)
                add_label(lpt[0], lpt[1], lpt[2], gold_lbl,
                          size=gold_lbl_size, color=gold_lbl_color,
                          opacity=lbl_opacity, outward=0)
    # (helper_words are invisible — only present for MDS geometry)
    # ── Neighbors ──
    if selected is not None and nbr_data and nbr_coords is not None:
        sel_color = item_colors[sel_idx]
        nbr_color = lighten(sel_color, 0.3)
        for i, (w, s) in enumerate(nbr_data):
            add_arrow(fig,
                      nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2],
                      nbr_color, width=ARROW_WIDTH,
                      sx=0, sy=0, sz=0, dash="dot")
            add_label(nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2],
                      w, size=16, color=DARK)
    fig.update_layout(**layout_3d(axis_range=axis_range, camera=camera),
                      scene_annotations=annotations)
    # ── Status text ──  (unused n_visible counter removed)
    n_hidden = len(hidden & set(labels))
    n_words = sum(1 for l, _, ie, _, _ in items if not ie and l not in hidden)
    n_expr = sum(1 for l, _, ie, _, _ in items if ie and l not in hidden)
    parts = []
    if n_words:
        parts.append(f"**{n_words} word{'s' if n_words != 1 else ''}**")
    if n_expr:
        parts.append(f"**{n_expr} expression{'s' if n_expr != 1 else ''}**")
    status = " + ".join(parts) + " in 3D" if parts else "Nothing visible"
    if n_hidden:
        status += f" · {n_hidden} hidden"
    if bad:
        status += f" · Not found: {', '.join(bad)}"
    for label in expr_nearest:
        if label not in hidden:
            w, s = expr_nearest[label]
            status += f" · **{label} ≈ {w}** ({s:.3f})"
    if nbr_data:
        status += f" · {len(nbr_data)} neighbors of **{selected}**"
    choices = ["(clear)"] + visible_labels
    return (
        fig,
        status,
        gr.update(choices=choices, value=selected or "(clear)", visible=True),
        labels,
    )
# ── Gradio UI ────────────────────────────────────────────────
# Custom CSS for the Blocks app. Selectors must stay in sync with the
# elem_classes used in the UI construction below.
CSS = """
.gradio-container { max-width: 100% !important; padding: 0 2rem !important; }
h1 { color: #63348d !important; }
/* Example buttons — dark purple outline */
.purple-examples td {
border: 2px solid #63348d !important;
border-radius: 6px !important;
color: #301848 !important;
cursor: pointer !important;
}
.purple-examples td:hover {
background: #ded9f4 !important;
}
/* Radio neighbor selector — purple text, white-on-purple when selected */
.nbr-radio label {
color: #63348d !important;
border: 1px solid #63348d !important;
border-radius: 6px !important;
}
.nbr-radio label.selected {
background: #63348d !important;
color: #ffffff !important;
}
.nbr-radio label.selected * {
color: #ffffff !important;
}
/* Visibility checkboxes — compact */
.vis-cbg label {
color: #63348d !important;
border: 1px solid #63348d !important;
border-radius: 6px !important;
}
.vis-cbg label.selected {
background: #63348d !important;
color: #ffffff !important;
}
.vis-cbg label.selected * {
color: #ffffff !important;
}
/* 3D viewport — full-width 16:9 with border */
.plot-viewport {
border: 2px solid #ded9f4 !important;
border-radius: 8px !important;
}
.plot-viewport .plot-container {
aspect-ratio: 16 / 9 !important;
width: 100% !important;
}
.plot-viewport .js-plotly-plot, .plot-viewport .plotly {
width: 100% !important;
height: 100% !important;
}
/* Neighbors dropdown — compact */
.nn-dropdown {
max-width: 100px !important;
}
.nn-dropdown select {
color: #63348d !important;
}
/* Hidden camera state textbox (visible=False prevents DOM rendering in Gradio 6) */
.camera-hidden { display: none !important; }
/* Input fields — white for contrast */
textarea, input[type="text"] {
background: #ffffff !important;
}
"""
# Currently an empty placeholder for a force-light-mode snippet.
FORCE_LIGHT = """
"""
# Theme overrides — every value is mirrored into its *_dark twin below so
# dark mode renders identically to light mode.
_LIGHT = {
    "button_primary_background_fill": "#63348d",
    "button_primary_background_fill_hover": "#4a2769",
    "button_primary_text_color": "#ffffff",
    "block_background_fill": "#f3f0f7",
    "block_border_color": "#ded9f4",
    "body_background_fill": "#ffffff",
    "body_text_color": "#1a1a2e",
    "block_label_text_color": "#63348d",
    "block_title_text_color": "#63348d",
    "background_fill_primary": "#ffffff",
    "background_fill_secondary": "#f3f0f7",
    "input_background_fill": "#ffffff",
    "input_background_fill_focus": "#ffffff",
    "input_border_color": "#ded9f4",
    "input_border_color_focus": "#63348d",
    "input_placeholder_color": "#888888",
    "border_color_primary": "#ded9f4",
    "border_color_accent": "#63348d",
    "panel_background_fill": "#f3f0f7",
}
# Mirror every light value into _dark so dark mode looks identical
_ALL = {}
for k, v in _LIGHT.items():
    _ALL[k] = v
    _ALL[k + "_dark"] = v
THEME = gr.themes.Soft(
    primary_hue="purple",
    font=gr.themes.GoogleFont("Inter"),
).set(**_ALL)
# NOTE(review): CSS/THEME defined above are not passed to gr.Blocks here —
# presumably applied where demo is launched, beyond this chunk; confirm.
with gr.Blocks(title="Embedding Explorer") as demo:
    # ── State ──
    all_labels_state = gr.State([])  # all item labels from the last explore
    loading_share = gr.State(False)  # suppress cascading events during share load
    # Hidden textbox bridging the JS camera state back to Python handlers.
    camera_txt = gr.Textbox(elem_id="camera_txt", elem_classes=["camera-hidden"])
    share_params = gr.State({})
    # Force light mode fallback (head param covers most cases, this catches HF Spaces)
    gr.HTML('')
    gr.Markdown(
        "# Embedding Explorer\n"
        "Words in AI models are stored as **vectors** — long lists of numbers "
        "that encode meaning. Similar words have similar vectors. "
        "You can even do **vector math**: *king − man + woman ≈ queen*. "
        "This tool lets you explore these representations in 3D using "
        "[GloVe](https://nlp.stanford.edu/projects/glove/) word vectors "
        f"({len(VOCAB):,} words, {DIMS} dimensions)."
    )
    gr.Markdown(
        "*Enter words to see them in 3D, or try vector arithmetic "
        "with `+` and `−`. Separate groups with commas. "
        "Click an item below the plot to see its nearest neighbors.*"
    )
    # Input row: text entry on the left, Explore/Share controls on the right.
    with gr.Row():
        with gr.Column(scale=2):
            exp_in = gr.Textbox(
                label="Words or expressions (comma-separated)",
                placeholder="dog cat fish or king - man + woman",
                lines=1,
            )
        with gr.Column(scale=1):
            with gr.Row():
                exp_btn = gr.Button("Explore", variant="primary")
                share_btn = gr.Button("Share", variant="secondary",
                                      scale=0, min_width=80)
            # Revealed (visible=True) only after a share link is generated.
            share_url = gr.Textbox(label="Share URL", visible=False,
                                   interactive=False, buttons=["copy"])
    with gr.Column(elem_classes=["purple-examples"]):
        gr.Examples(
            examples=[[e] for e in EXAMPLES],
            inputs=[exp_in],
            label="Try these",
        )
    exp_plot = gr.Plot(label="Embedding Space", elem_classes=["plot-viewport"])
    exp_status = gr.Markdown("")
    # Per-item visibility toggles; populated after each explore.
    vis_cbg = gr.CheckboxGroup(
        label="Visible items (uncheck to hide)",
        choices=[], value=[],
        visible=False, interactive=True,
        elem_classes=["vis-cbg"],
    )
    with gr.Row():
        # Neighbor selector; "(clear)" sentinel deselects.
        exp_radio = gr.Radio(
            label="Click to see nearest neighbors",
            choices=[], value=None,
            visible=False, interactive=True,
            elem_classes=["nbr-radio"],
        )
        nn_dropdown = gr.Dropdown(
            label="Neighbors",
            choices=[str(i) for i in range(3, 13)],
            value=str(N_NEIGHBORS),
            interactive=True,
            scale=0,
            min_width=90,
            elem_classes=["nn-dropdown"],
        )
def _parse_camera_json(camera_json):
"""Parse camera JSON string (from JS bridge) into Plotly camera dict."""
if not camera_json:
return None
try:
return json.loads(camera_json)
except (json.JSONDecodeError, TypeError):
return None
def _get_nn(nn_val):
"""Parse neighbor count from dropdown value."""
try:
return int(nn_val)
except (TypeError, ValueError):
return N_NEIGHBORS
def on_explore(input_text, nn_val=None):
    """Fresh explore — compute MDS, show all items, reset checkboxes.

    Supports @word syntax to auto-select a word for neighbors:
        dog cat fish @dog → plots all 3, shows dog's neighbors
    """
    selected = None
    if input_text and "@" in input_text:
        tagged = re.search(r"@(\S+)", input_text)
        if tagged:
            selected = tagged.group(1).lower()
            # Strip the @tag from the text shown in the input box.
            input_text = re.sub(r"\s*@\S+", "", input_text).strip()
    fig, status, radio, labels = explore(input_text, selected,
                                         n_neighbors=_get_nn(nn_val))
    cbg = gr.update(choices=labels, value=labels, visible=bool(labels))
    return fig, status, radio, labels, cbg, gr.update(value=input_text)
def on_radio(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Neighbor selection — re-render with current visibility + camera."""
    if is_loading:
        # Share-load in progress: swallow the cascading event.
        return gr.update(), gr.update(), gr.update(), False
    hidden = set(all_labels) - set(visible) if all_labels and visible else set()
    fig, status, radio, _ = explore(
        input_text, selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_visibility(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Visibility toggle — re-render with updated hidden set + camera."""
    if is_loading:
        return gr.update(), gr.update(), gr.update(), False
    hidden = set(all_labels) - set(visible) if all_labels else set()
    # Clear the selection when the selected item was just hidden.
    if selected and selected != "(clear)" and selected in hidden:
        selected = None
    fig, status, radio, _ = explore(
        input_text, selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_nn_change(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Neighbor count changed — re-render if a word is selected."""
    # Nothing to redraw while a share URL is loading or with no active selection.
    if is_loading or not selected or selected == "(clear)":
        return gr.update(), gr.update(), gr.update(), False
    hidden = set(all_labels) - set(visible) if all_labels and visible else set()
    fig, status, radio, _ = explore(
        input_text,
        selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_share(input_text, selected, visible, camera_json, nn_val, request: gr.Request):
    """Build a shareable URL encoding the current explorer state.

    Encodes the query text (q), selected word (sel), visible items (vis),
    camera pose (cam) and a non-default neighbor count (nn) as URL query
    parameters, then shortens the result unless running on localhost.
    """
    params = {}
    if input_text and input_text.strip():
        params["q"] = input_text.strip()
    if selected and selected != "(clear)":
        params["sel"] = selected
    # Encode the visible-item list when non-empty. (The old comment claimed
    # "only if some items are hidden", but all_labels isn't available here,
    # so the full list is encoded; an empty list round-trips identically to
    # omitting the parameter, so it is skipped.)
    if isinstance(visible, list) and visible:
        params["vis"] = ",".join(visible)
    if camera_json:
        encoded = _encode_camera(camera_json)
        if encoded:
            params["cam"] = encoded
    nn = _get_nn(nn_val)
    if nn != N_NEIGHBORS:
        params["nn"] = str(nn)
    if not params.get("q"):
        return gr.update(value="Nothing to share", visible=True)
    # Build base URL from the request (gets the correct port for local dev).
    base_url = _BASE_URL
    if request:
        host = request.headers.get("host", "")
        if host:
            scheme = "https" if _SPACE_ID else "http"
            base_url = f"{scheme}://{host}/"
    long_url = base_url + "?" + urllib.parse.urlencode(params)
    # On localhost, return the full URL (Rebrandly rejects non-public URLs).
    if "localhost" in long_url or "127.0.0.1" in long_url:
        return gr.update(value=long_url, visible=True)
    return gr.update(value=_shorten_url(long_url), visible=True)
# ── Wire up events ──
_EXAMPLE_SET = set(EXAMPLES)
def on_input_change(input_text, nn_val):
    """Auto-explore when the input text exactly matches one of the
    configured examples (this fires when a gr.Examples row is clicked)."""
    text = input_text.strip() if input_text else ""
    if text in _EXAMPLE_SET:
        return on_explore(input_text, nn_val)
    # Manual typing: leave every output untouched.
    return tuple(gr.update() for _ in range(6))
# Input change fires on gr.Examples clicks (see on_input_change's guard);
# explicit Explore button and Enter key both run a fresh explore.
exp_in.change(
on_input_change,
inputs=[exp_in, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in],
)
exp_btn.click(
on_explore,
inputs=[exp_in, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in],
)
exp_in.submit(
on_explore,
inputs=[exp_in, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in],
)
# Radio + visibility + nn: camera_txt is kept up-to-date by polling script
# All three re-render handlers share the same input/output wiring; loading_share
# suppresses the cascade while a share URL is being applied.
exp_radio.change(
on_radio,
inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, loading_share],
)
vis_cbg.change(
on_visibility,
inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, loading_share],
)
nn_dropdown.change(
on_nn_change,
inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown],
outputs=[exp_plot, exp_status, exp_radio, loading_share],
)
# Share: camera_txt kept up-to-date by polling script
share_btn.click(
fn=on_share,
inputs=[exp_in, exp_radio, vis_cbg, camera_txt, nn_dropdown],
outputs=[share_url],
)
# ── Share URL loading ──
def load_share_params(request: gr.Request):
    """Step 1: Parse query params from URL."""
    if not request:
        return {}
    return dict(request.query_params)
def apply_share_params(params):
    """Step 2: Apply share params — set input, run explore, apply visibility + camera + nn.

    Returns the 9-tuple wired to
    (exp_in, exp_plot, exp_status, exp_radio, vis_cbg,
     all_labels_state, camera_txt, loading_share, nn_dropdown).
    """
    params = params or {}
    if "q" not in params:
        # No query to restore; still honor a bare nn parameter if present.
        nn_str = params.get("nn")
        nn_update = gr.update(value=nn_str) if nn_str else gr.update()
        return (
            gr.update(),   # exp_in
            gr.update(),   # exp_plot
            gr.update(),   # exp_status
            gr.update(),   # exp_radio
            gr.update(),   # vis_cbg
            [],            # all_labels_state
            gr.update(),   # camera_txt
            False,         # loading_share
            nn_update,     # nn_dropdown
        )
    input_text = params.get("q", "")
    selected = params.get("sel") or None
    camera = _parse_camera(params.get("cam"))
    nn_str = params.get("nn")
    nn = int(nn_str) if nn_str and nn_str.isdigit() else None
    # First explore with everything visible just to learn the label list.
    _, _, _, labels = explore(input_text, None, camera=camera, n_neighbors=nn)
    vis_str = params.get("vis")
    if vis_str:
        visible = [v.strip() for v in vis_str.split(",")]
        hidden = set(labels) - set(visible)
    else:
        visible = labels
        hidden = set()
    fig, status, radio, _ = explore(
        input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=nn
    )
    cbg = gr.update(choices=labels, value=visible, visible=bool(labels))
    # Pre-populate camera_txt so subsequent re-renders preserve the camera pose.
    camera_json = json.dumps(camera) if camera else ""
    nn_update = gr.update(value=str(nn)) if nn else gr.update()
    return (
        gr.update(value=input_text),   # exp_in
        fig,                           # exp_plot
        status,                        # exp_status
        radio,                         # exp_radio
        cbg,                           # vis_cbg
        labels,                        # all_labels_state
        gr.update(value=camera_json),  # camera_txt
        True,                          # loading_share — suppress cascading events
        nn_update,                     # nn_dropdown
    )
# Two-step share-URL restore: read the query params on page load, then
# rebuild the full UI state from them.
demo.load(
fn=load_share_params,
outputs=[share_params],
).then(
fn=apply_share_params,
inputs=[share_params],
outputs=[exp_in, exp_plot, exp_status, exp_radio, vis_cbg, all_labels_state, camera_txt, loading_share, nn_dropdown],
)
# NOTE(review): theme/css/head are normally gr.Blocks(...) constructor kwargs,
# not launch() kwargs — confirm the installed Gradio version accepts them here.
demo.launch(theme=THEME, css=CSS, head=FORCE_LIGHT)