""" Embedding Explorer — Interactive word vector visualization Responsible AI: Technology, Power, and Justice Huston-Tillotson University Enter words or vector expressions (comma-separated). Each item becomes an arrow in 3D. Click an item to see its nearest neighbors. Configuration (HuggingFace Space environment variables): EXAMPLES — JSON list of example inputs (words or expressions) N_NEIGHBORS — number of neighbors to show on click (default 8) """ import os import json import re import warnings import urllib.parse import subprocess import pandas # noqa: F401 — import before plotly to avoid circular import import numpy as np import plotly.graph_objects as go import gradio as gr warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") # ── Configuration (all changeable via HF Space env vars) ───── EXAMPLES = json.loads(os.environ.get("EXAMPLES", json.dumps([ "dog cat fish car truck", "paris france berlin germany tokyo japan", "man woman king queen prince princess", "man - woman, uncle - aunt, man woman uncle aunt", "aunt - woman + man, man woman uncle aunt", "nephew - man + woman, man woman nephew niece", "king - man + woman, man woman king queen", "paris - france + italy, paris france italy rome", "sushi - japan + germany, sushi japan germany bratwurst", "hitler - germany + italy, germany italy hitler mussolini", ]))) N_NEIGHBORS = int(os.environ.get("N_NEIGHBORS", "4")) # ── Share URL infrastructure ───────────────────────────────── REBRANDLY_API_KEY = os.environ.get("REBRANDLY_API_KEY", "") _SPACE_ID = os.environ.get("SPACE_ID", "") if _SPACE_ID: _owner, _name = _SPACE_ID.split("/") _BASE_URL = f"https://{_owner}-{_name}.hf.space/" else: _BASE_URL = "http://localhost:7860/" # ── Course design system colors ────────────────────────────── PURPLE = "#63348d" PURPLE_LIGHT = "#ded9f4" PURPLE_DARK = "#301848" GOLD = "#f0c040" PINK = "#de95a0" DARK = "#1a1a2e" GRAY = "#888888" BG = "#fafafa" # Categorical palette — darkest shade from each design-system color family # 10 shades per family from color_palette.md (darkest → lightest) _SHADES = { "purple": ["#64348d", "#6f3ba4", "#7942bb", "#8455c5", "#906acd", "#9d80d6", "#ab95de", "#baabe5", "#cbc1ec", "#ddd7f4"], "blue": ["#344b8d", "#3b58a4", "#4266bb", "#5579c5", "#6a8dcd", "#809fd6", "#95b2de", "#abc3e5", "#c1d4ec", "#d7e5f4"], "green": ["#348d64", "#3ba470", "#42bb7c", "#55c588", "#6acd95", "#80d6a2", "#95deb0", "#abe5bf", "#c1eccf", "#d7f4e0"], "red": ["#8d3437", "#a43b41", "#bb424a", "#c5555f", "#cd6a75", "#d6808a", "#de95a0", "#e5abb4", "#ecc1c9", "#f4d7dd"], "yellow": ["#8d7734", "#a48b3b", "#bb9f42", "#c5ac55", "#cdb86a", "#d6c480", "#decf95", "#e5daab", "#ece5c1", "#f4efd7"], "orange": ["#8d5534", "#a4633b", "#bb7242", "#c58355", "#cd946a", "#d6a580", "#deb695", "#e5c6ab", "#ecd6c1", "#f4e5d7"], } # Darkest shade from each family — up to 12 items cycle through these 6 PALETTE = [ _SHADES["purple"][0], _SHADES["blue"][0], _SHADES["red"][0], _SHADES["yellow"][0], _SHADES["green"][0], _SHADES["orange"][0], ] # Map each darkest color to its shade family for lookup _COLOR_FAMILY = {shades[0]: shades for shades in _SHADES.values()} def lighten(hex_color, amount=0.3): """Lighten a color using the design-system palette shades. Maps amount (0.0–1.0) to palette shade index. Falls back to arithmetic blending for colors not in the palette. """ base = hex_color.lower() if base in _COLOR_FAMILY: shades = _COLOR_FAMILY[base] idx = min(int(amount * len(shades)), len(shades) - 1) return shades[idx] # Fallback for non-palette colors (e.g., GOLD) h = hex_color.lstrip("#") r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) r = int(r + (255 - r) * amount) g = int(g + (255 - g) * amount) b = int(b + (255 - b) * amount) return f"#{r:02x}{g:02x}{b:02x}" # ── Load GloVe embeddings on startup ───────────────────────── import time import gensim.downloader as api from gensim.models import KeyedVectors # Native binary cache — loads ~10x faster than gensim's text format _CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".cache") _CACHE_PATH = os.path.join(_CACHE_DIR, "glove-wiki-gigaword-300.kv") def load_model(name="glove-wiki-gigaword-300", retries=5): """Load GloVe vectors. Uses native binary cache for fast startup after first run.""" # Fast path: load from native binary cache (memory-mapped) if os.path.exists(_CACHE_PATH): print("=" * 60) print("Loading GloVe vectors from cache (memory-mapped)...") print("=" * 60) t0 = time.time() m = KeyedVectors.load(_CACHE_PATH, mmap="r") print(f"Loaded in {time.time() - t0:.1f}s") return m # Slow path: download via gensim, then save native cache for attempt in range(1, retries + 1): try: print("=" * 60) print(f"Downloading GloVe word vectors ({name})...") if attempt == 1: print("First run only — ~376 MB download. Will cache for fast startup.") else: print(f"Retry {attempt}/{retries}...") print("=" * 60) m = api.load(name) # Save native binary cache for next time os.makedirs(_CACHE_DIR, exist_ok=True) m.save(_CACHE_PATH) print(f"Cached to {_CACHE_DIR} for fast startup next time.") return m except Exception as e: print(f"Attempt {attempt} failed: {e}") if attempt < retries: wait = 2 ** attempt print(f"Retrying in {wait}s...") time.sleep(wait) else: raise model = load_model() VOCAB = set(model.key_to_index.keys()) DIMS = model.vector_size print(f"Ready: {len(VOCAB):,} words, {DIMS} dimensions each") # ── Helpers ────────────────────────────────────────────────── def parse_expression(expr): """Parse 'king - man + woman' → (positives, negatives, ordered). ordered is [(word, sign, coeff), ...] for display formatting. Supports fractional coefficients: '0.5 king - 0.3 man + 1.5 woman'. """ tokens = re.findall(r"\d*\.?\d+|[a-z']+|[+-]", expr.lower()) pos, neg, ordered = [], [], [] sign = "+" coeff = 1.0 for t in tokens: if t in "+-": sign = t coeff = 1.0 elif re.match(r"^\d*\.?\d+$", t): coeff = float(t) elif t in VOCAB: (pos if sign == "+" else neg).append((t, coeff)) ordered.append((t, sign, coeff)) coeff = 1.0 return pos, neg, ordered def _coeff_str(c): """Format coefficient for display. Returns '' for 1.0, '0.5\u00b7' otherwise.""" if c == 1.0: return "" if c == int(c): return f"{int(c)}\u00b7" return f"{c:g}\u00b7" def parse_items(text): """Parse comma-separated input into items (words or vector expressions). Returns (items, bad_words) where each item is: (label, vector, is_expr, operand_words, ordered_ops_or_None) """ if not text or not text.strip(): return [], [] items = [] bad_words = [] seen_labels = set() # Split by comma → parts parts = [p.strip() for p in text.split(",")] for part in parts: if not part: continue # Detect expression: contains + or - between word characters or digits if re.search(r"(?:[a-z']|\d)\s*[+\-]\s*(?:[a-z']|\d)", part.lower()): # It's an arithmetic expression pos, neg, ordered = parse_expression(part) if len(pos) + len(neg) < 2: # Check for bad words all_tokens = re.findall(r"[a-z']+", part.lower()) bad_words.extend(t for t in all_tokens if t not in VOCAB) continue # Compute result vector vec = np.zeros(DIMS) for w, c in pos: vec += c * model[w] for w, c in neg: vec -= c * model[w] # Build label label_parts = [] for w, s, c in ordered: cstr = _coeff_str(c) if not label_parts: label_parts.append(f"{cstr}{w}") elif s == "+": label_parts.append(f"+ {cstr}{w}") else: label_parts.append(f"− {cstr}{w}") label = " ".join(label_parts) if label not in seen_labels: seen_labels.add(label) operand_words = set(w for w, c in pos + neg) items.append((label, vec, True, operand_words, list(ordered))) else: # Plain words — each word is a separate item words = re.split(r"\s+", part.lower().strip()) for w in words: w = w.strip() if not w: continue if w not in VOCAB: bad_words.append(w) continue if w not in seen_labels: seen_labels.add(w) items.append((w, model[w], False, {w}, None)) return items, bad_words def reduce_3d(vectors): """MDS (cosine distance) → 3D. Normalizes to [-1, 1] for consistent label sizing.""" from sklearn.manifold import MDS from sklearn.metrics.pairwise import cosine_distances n = len(vectors) if n < 2: return np.zeros((n, 3)) dist = cosine_distances(vectors) nc = min(3, n) mds = MDS(n_components=nc, dissimilarity="precomputed", random_state=42, normalized_stress="auto", max_iter=300) coords = mds.fit_transform(dist) if nc < 3: coords = np.hstack([coords, np.zeros((n, 3 - nc))]) # Normalize to [-1, 1] so axis ranges and label sizes are consistent max_abs = np.abs(coords).max() if max_abs > 1e-8: coords = coords / max_abs return coords def _axis(title=""): """3D axis — minimal, no built-in grid (we draw our own floor grid).""" return dict( showgrid=False, zeroline=False, showticklabels=False, title=title, showspikes=False, showbackground=False, ) def layout_3d(axis_range=1.3, camera=None): """Shared Plotly 3D layout. Uses uirevision to preserve camera across updates.""" ax_x, ax_y, ax_z = _axis(), _axis(), _axis() fixed = [-axis_range, axis_range] ax_x["range"] = fixed ax_y["range"] = fixed ax_z["range"] = fixed default_camera = dict(eye=dict(x=1.0, y=1.0, z=0.8)) return dict( scene=dict( xaxis=ax_x, yaxis=ax_y, zaxis=ax_z, bgcolor="white", camera=camera or default_camera, aspectmode="cube", ), paper_bgcolor="white", margin=dict(l=0, r=0, t=10, b=10), showlegend=True, legend=dict( yanchor="top", y=0.99, xanchor="right", x=0.99, bgcolor="rgba(255,255,255,0.85)", font=dict(family="Inter, sans-serif", size=12), ), font=dict(family="Inter, sans-serif"), autosize=True, uirevision="keep", ) ARROW_WIDTH = 10 # line width in pixels ARROW_HEAD_LENGTH = 0.08 # arrowhead length in coordinate units ARROW_HEAD_WIDTH = 0.03 # arrowhead half-width in coordinate units def add_arrow(fig, px, py, pz, color, width=ARROW_WIDTH, head_length=ARROW_HEAD_LENGTH, head_width=ARROW_HEAD_WIDTH, sx=0, sy=0, sz=0, dash=None): """Draw a vector arrow from (sx,sy,sz) to (px,py,pz) with a flat arrowhead.""" start = np.array([sx, sy, sz]) tip = np.array([px, py, pz]) vec = tip - start length = np.linalg.norm(vec) if length < 1e-8: return d = vec / length # unit direction # Shorten line so it doesn't overlap arrowhead shorten = min(head_length, length * 0.3) end = tip - d * shorten # Vector line fig.add_trace(go.Scatter3d( x=[start[0], end[0]], y=[start[1], end[1]], z=[start[2], end[2]], mode="lines", line=dict(color=color, width=width, dash=dash), showlegend=False, hoverinfo="none", )) # Flat arrowhead using Mesh3d (diamond-cross pyramid, visible from all angles) up = np.array([0, 0, 1]) if abs(d[2]) < 0.9 else np.array([0, 1, 0]) p1 = np.cross(d, up) p1 = p1 / np.linalg.norm(p1) p2 = np.cross(d, p1) base = tip - d * head_length w1 = base + p1 * head_width w2 = base - p1 * head_width w3 = base + p2 * head_width w4 = base - p2 * head_width # 5 vertices: tip + 4 base points vx = [tip[0], w1[0], w2[0], w3[0], w4[0]] vy = [tip[1], w1[1], w2[1], w3[1], w4[1]] vz = [tip[2], w1[2], w2[2], w3[2], w4[2]] # 4 triangular faces forming the pyramid fig.add_trace(go.Mesh3d( x=vx, y=vy, z=vz, i=[0, 0, 0, 0], j=[1, 3, 2, 4], k=[3, 2, 4, 1], color=color, opacity=1.0, flatshading=True, lighting=dict(ambient=1, diffuse=0, specular=0, fresnel=0), showlegend=False, hoverinfo="none", )) def add_floor_grid(fig, range_val=1.3, step=0.25): """Draw a grid on the z=0 plane and axis lines through origin (3B1B style).""" grid_color = "rgba(99,52,141,0.18)" axis_color = "rgba(99,52,141,0.60)" # Grid lines on z=0 plane — batched with None separators for efficiency vals = np.arange(-range_val, range_val + step / 2, step) # Lines parallel to Y axis (varying x) xs, ys, zs = [], [], [] for v in vals: xs.extend([v, v, None]) ys.extend([-range_val, range_val, None]) zs.extend([0, 0, None]) fig.add_trace(go.Scatter3d( x=xs, y=ys, z=zs, mode="lines", line=dict(color=grid_color, width=1), showlegend=False, hoverinfo="none", )) # Lines parallel to X axis (varying y) xs, ys, zs = [], [], [] for v in vals: xs.extend([-range_val, range_val, None]) ys.extend([v, v, None]) zs.extend([0, 0, None]) fig.add_trace(go.Scatter3d( x=xs, y=ys, z=zs, mode="lines", line=dict(color=grid_color, width=1), showlegend=False, hoverinfo="none", )) # Three axis lines through origin for ax in [ ([-range_val, range_val], [0, 0], [0, 0]), # X ([0, 0], [-range_val, range_val], [0, 0]), # Y ([0, 0], [0, 0], [-range_val, range_val]), # Z ]: fig.add_trace(go.Scatter3d( x=list(ax[0]), y=list(ax[1]), z=list(ax[2]), mode="lines", line=dict(color=axis_color, width=2), showlegend=False, hoverinfo="none", )) def blank(msg): """Empty placeholder figure with a centered message.""" fig = go.Figure() fig.add_annotation( text=msg, xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16, color=GRAY, family="Inter, sans-serif"), ) fig.update_layout( xaxis_visible=False, yaxis_visible=False, height=560, paper_bgcolor="white", plot_bgcolor="white", margin=dict(l=0, r=0, t=0, b=0), ) return fig # ── Share URL helpers ──────────────────────────────────────── def _shorten_url(long_url): """Shorten a URL via Rebrandly API (using curl). Falls back to long URL.""" if not REBRANDLY_API_KEY or "localhost" in long_url: return long_url try: payload = json.dumps({ "destination": long_url, "domain": {"fullName": "go.ropavieja.org"}, }) result = subprocess.run( [ "curl", "-s", "-X", "POST", "https://api.rebrandly.com/v1/links", "-H", "Content-Type: application/json", "-H", f"apikey: {REBRANDLY_API_KEY}", "-d", payload, ], capture_output=True, text=True, timeout=10, ) if result.returncode != 0 or not result.stdout.strip(): return long_url data = json.loads(result.stdout) return f"https://{data['shortUrl']}" except (subprocess.TimeoutExpired, KeyError, json.JSONDecodeError, OSError) as exc: print(f"[share] Rebrandly error: {exc}") return long_url def _parse_camera(cam_str): """Parse compact camera string (ex,ey,ez[,cx,cy,cz,ux,uy,uz]) to Plotly camera dict.""" if not cam_str: return None try: vals = [float(v) for v in cam_str.split(",")] if len(vals) >= 3: camera = dict(eye=dict(x=vals[0], y=vals[1], z=vals[2])) if len(vals) >= 6: camera["center"] = dict(x=vals[3], y=vals[4], z=vals[5]) if len(vals) >= 9: camera["up"] = dict(x=vals[6], y=vals[7], z=vals[8]) return camera except (ValueError, IndexError): pass return None def _encode_camera(camera_json): """Encode Plotly camera JSON to compact string for URL params.""" if not camera_json: return "" try: cam = json.loads(camera_json) eye = cam.get("eye", {}) center = cam.get("center", {}) up = cam.get("up", {}) vals = [ eye.get("x", 1.5), eye.get("y", 1.5), eye.get("z", 1.2), center.get("x", 0), center.get("y", 0), center.get("z", 0), up.get("x", 0), up.get("y", 0), up.get("z", 1), ] return ",".join(f"{v:.2f}" for v in vals) except (json.JSONDecodeError, TypeError): return "" # ── Main visualization ─────────────────────────────────────── def explore(input_text, selected, hidden=None, camera=None, n_neighbors=None): """Unified 3D visualization of words and vector expressions. Args: input_text: Comma-separated words and/or expressions. selected: Currently selected item for neighbor display (or None). hidden: Set of labels to hide from rendering (MDS still uses all items). camera: Plotly camera dict to set initial view. n_neighbors: Number of nearest neighbors to show (default N_NEIGHBORS). Returns: (fig, status_md, radio_update, all_labels) """ if not input_text or not input_text.strip(): return ( blank("Enter words or expressions above to visualize in 3D"), "", gr.update(choices=[], value=None, visible=False), [], ) items, bad = parse_items(input_text) if not items: msg = "No valid items found." if bad: msg += f"
Not in vocabulary: {', '.join(bad)}" return blank(msg), "", gr.update(choices=[], value=None, visible=False), [] items = items[:12] labels = [item[0] for item in items] hidden = hidden or set() # Visible labels for radio choices (exclude hidden) visible_labels = [l for l in labels if l not in hidden] # No auto-select — user clicks radio to see neighbors if selected and selected != "(clear)" and selected in visible_labels: sel_idx = labels.index(selected) else: selected = None sel_idx = None # ── Collect all unique operand words for MDS ── all_words = [] word_set = set() for _, _, _, ops, _ in items: for w in ops: if w not in word_set: word_set.add(w) all_words.append(w) # Find nearest high-D word for each expression (used in status + MDS padding) expr_nearest = {} # label -> (word, similarity) for label, vec, is_expr, ops, ordered in items: if not is_expr: continue nearest = model.similar_by_vector(vec, topn=len(ops) + 5) for w, s in nearest: if w not in ops: expr_nearest[label] = (w, s) break # Pad with neighbor words if < 3 unique words (breaks MDS collinearity) helper_words = set() if len(all_words) < 3: # Try expression nearest-words first for label in expr_nearest: nw = expr_nearest[label][0] if nw not in word_set: word_set.add(nw) all_words.append(nw) helper_words.add(nw) if len(all_words) >= 3: break # Then try neighbors of any plain word if len(all_words) < 3: for w in list(all_words): if w in helper_words: continue for nw, _ in model.most_similar(w, topn=5): if nw not in word_set: word_set.add(nw) all_words.append(nw) helper_words.add(nw) if len(all_words) >= 3: break if len(all_words) >= 3: break # Gather neighbors if something is selected (and not hidden) nn = n_neighbors if n_neighbors is not None else N_NEIGHBORS nbr_data = [] if selected is not None: sel_item = items[sel_idx] if sel_item[2]: # expression raw = model.similar_by_vector(sel_item[1], topn=nn + 20) else: raw = model.most_similar(sel_item[0], topn=nn + 20) all_op_words = set() for _, _, _, ops, _ in items: all_op_words.update(ops) label_set = set(labels) nbr_data = [(w, s) for w, s in raw if w not in all_op_words and w not in label_set ][:nn] # ── MDS on all operand words + neighbors ── mds_words = all_words + [w for w, _ in nbr_data] if not mds_words: return blank("No valid words found."), "", gr.update( choices=[], value=None, visible=False), [] mds_vecs = np.array([model[w] for w in mds_words]) mds_coords = reduce_3d(mds_vecs) word_3d = {w: mds_coords[i] for i, w in enumerate(all_words)} nbr_coords = mds_coords[len(all_words):] if nbr_data else None # ── Compute expression results in 3D ── extra_points = [] # for dynamic axis range expr_info = {} # label -> visualization data for label, vec, is_expr, ops, ordered in items: if not is_expr: continue pos_words = [w for w, s, c in ordered if s == "+"] neg_words = [w for w, s, c in ordered if s == "-"] pos_coeffs = [c for w, s, c in ordered if s == "+"] neg_coeffs = [c for w, s, c in ordered if s == "-"] if len(neg_words) == 0 and len(pos_words) == 2: # Simple addition: chain tip-to-tail + gold result from origin a_3d = word_3d[pos_words[0]] b_3d = word_3d[pos_words[1]] result_3d = pos_coeffs[0] * a_3d + pos_coeffs[1] * b_3d extra_points.append(result_3d) expr_info[label] = ('add', pos_words[0], pos_words[1], pos_coeffs[0], pos_coeffs[1], result_3d) else: # General: chain through operands + gold result from origin cursor = np.zeros(3) chain = [] for w, s, coeff in ordered: prev = cursor.copy() c = word_3d[w] cursor = cursor + coeff * c if s == "+" else cursor - coeff * c chain.append((prev.copy(), cursor.copy(), w, s)) extra_points.append(cursor.copy()) expr_info[label] = ('chain', chain, cursor.copy()) # ── Dynamic axis range (computed from ALL items, not just visible) ── all_rendered = [word_3d[w] for w in all_words] + extra_points if nbr_coords is not None: all_rendered.extend(nbr_coords) if all_rendered: max_abs = np.abs(np.array(all_rendered)).max() axis_range = max(1.3, max_abs * 1.15) else: axis_range = 1.3 # ── Colors ── item_colors = [PALETTE[i % len(PALETTE)] for i in range(len(items))] # ── Build figure ── fig = go.Figure() add_floor_grid(fig, range_val=axis_range) annotations = [] def add_label(x, y, z, text, size=16, color=DARK, opacity=1.0, outward=0.07): """Add a 3D text label, shifted outward from origin to avoid overlapping vectors.""" pt = np.array([x, y, z]) norm = np.linalg.norm(pt) if norm > 1e-8 and outward > 0: pt = pt + (pt / norm) * outward annotations.append(dict( x=float(pt[0]), y=float(pt[1]), z=float(pt[2]), text=text, showarrow=False, font=dict(size=size, color=color, family="Inter, sans-serif"), opacity=opacity, yshift=12, )) for idx, (label, vec, is_expr, ops, ordered) in enumerate(items): # Skip hidden items if label in hidden: continue color = item_colors[idx] is_sel = (sel_idx is not None and idx == sel_idx) is_dim = (sel_idx is not None and idx != sel_idx) # Selection-aware styling arr_color = lighten(color, 0.5) if is_dim else color arr_width = 4 if is_dim else (12 if is_sel else ARROW_WIDTH) lbl_color = lighten(color, 0.5) if is_dim else color lbl_opacity = 0.7 if is_dim else 1.0 lbl_size = 18 if is_sel else (15 if is_dim else 16) gold = lighten(GOLD, 0.3) if is_dim else GOLD gold_width = 6 if is_dim else 12 if not is_expr: # ── Plain word: arrow from origin ── c = word_3d[label] add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width) txt = f"{label}" if is_sel else label add_label(c[0], c[1], c[2], txt, size=lbl_size, color=lbl_color, opacity=lbl_opacity) else: info = expr_info[label] # ── Operand arrows from origin (full length — shows where the word IS) ── for w, s, coeff in ordered: c = word_3d[w] add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width) txt = f"{w}" if is_sel else w add_label(c[0], c[1], c[2], txt, size=lbl_size, color=lbl_color, opacity=lbl_opacity) # ── Construction arrows with expression labels beside midpoint ── gold_lbl = f"{label}" gold_lbl_size = 14 if is_dim else 15 gold_lbl_color = "#b08820" if is_dim else "#8a6a10" def _label_beside(start, end): """Place label at midpoint of start→end, offset perpendicular to the arrow (on the side farther from origin).""" mid = (start + end) / 2 d = end - start length = np.linalg.norm(d) if length < 1e-8: return mid d = d / length up = np.array([0., 0., 1.]) if abs(d[2]) < 0.9 \ else np.array([0., 1., 0.]) perp = np.cross(d, up) pn = np.linalg.norm(perp) if pn < 1e-8: return mid perp = perp / pn * 0.12 # Pick side farther from origin (pushes label outward) if np.dot(mid + perp, mid + perp) >= np.dot(mid - perp, mid - perp): return mid + perp return mid - perp if info[0] == 'add': # Second operand drawn from tip of first (chain) # info = ('add', word_a, word_b, coeff_a, coeff_b, result_3d) a = word_3d[info[1]] coeff_a = info[3] coeff_b = info[4] result_3d = info[5] chain_start = coeff_a * a chain_color = lighten(color, 0.3) if is_dim else lighten(color, 0.2) add_arrow(fig, result_3d[0], result_3d[1], result_3d[2], chain_color, width=arr_width, sx=chain_start[0], sy=chain_start[1], sz=chain_start[2], dash="dot") # Gold result from origin add_arrow(fig, result_3d[0], result_3d[1], result_3d[2], gold, width=gold_width) origin = np.zeros(3) lpt = _label_beside(origin, result_3d) add_label(lpt[0], lpt[1], lpt[2], gold_lbl, size=gold_lbl_size, color=gold_lbl_color, opacity=lbl_opacity, outward=0) else: # chain chain_steps = info[1] result_3d = info[2] chain_color = lighten(color, 0.3) if is_dim else lighten(color, 0.2) for i, (start, end, w, s) in enumerate(chain_steps): if i == 0: continue # first step overlaps operand arrow add_arrow(fig, end[0], end[1], end[2], chain_color, width=arr_width, sx=start[0], sy=start[1], sz=start[2], dash="dot") # Gold result from origin add_arrow(fig, result_3d[0], result_3d[1], result_3d[2], gold, width=gold_width) origin = np.zeros(3) lpt = _label_beside(origin, result_3d) add_label(lpt[0], lpt[1], lpt[2], gold_lbl, size=gold_lbl_size, color=gold_lbl_color, opacity=lbl_opacity, outward=0) # (helper_words are invisible — only present for MDS geometry) # ── Neighbors ── if selected is not None and nbr_data and nbr_coords is not None: sel_color = item_colors[sel_idx] nbr_color = lighten(sel_color, 0.3) for i, (w, s) in enumerate(nbr_data): add_arrow(fig, nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2], nbr_color, width=ARROW_WIDTH, sx=0, sy=0, sz=0, dash="dot") add_label(nbr_coords[i, 0], nbr_coords[i, 1], nbr_coords[i, 2], w, size=16, color=DARK) fig.update_layout(**layout_3d(axis_range=axis_range, camera=camera), scene_annotations=annotations) # ── Status text ── n_visible = sum(1 for l, _, ie, _, _ in items if l not in hidden) n_hidden = len(hidden & set(labels)) n_words = sum(1 for l, _, ie, _, _ in items if not ie and l not in hidden) n_expr = sum(1 for l, _, ie, _, _ in items if ie and l not in hidden) parts = [] if n_words: parts.append(f"**{n_words} word{'s' if n_words != 1 else ''}**") if n_expr: parts.append(f"**{n_expr} expression{'s' if n_expr != 1 else ''}**") status = " + ".join(parts) + " in 3D" if parts else "Nothing visible" if n_hidden: status += f" · {n_hidden} hidden" if bad: status += f" · Not found: {', '.join(bad)}" for label in expr_nearest: if label not in hidden: w, s = expr_nearest[label] status += f" · **{label} ≈ {w}** ({s:.3f})" if nbr_data: status += f" · {len(nbr_data)} neighbors of **{selected}**" choices = ["(clear)"] + visible_labels return ( fig, status, gr.update(choices=choices, value=selected or "(clear)", visible=True), labels, ) # ── Gradio UI ──────────────────────────────────────────────── CSS = """ .gradio-container { max-width: 100% !important; padding: 0 2rem !important; } h1 { color: #63348d !important; } /* Example buttons — dark purple outline */ .purple-examples td { border: 2px solid #63348d !important; border-radius: 6px !important; color: #301848 !important; cursor: pointer !important; } .purple-examples td:hover { background: #ded9f4 !important; } /* Radio neighbor selector — purple text, white-on-purple when selected */ .nbr-radio label { color: #63348d !important; border: 1px solid #63348d !important; border-radius: 6px !important; } .nbr-radio label.selected { background: #63348d !important; color: #ffffff !important; } .nbr-radio label.selected * { color: #ffffff !important; } /* Visibility checkboxes — compact */ .vis-cbg label { color: #63348d !important; border: 1px solid #63348d !important; border-radius: 6px !important; } .vis-cbg label.selected { background: #63348d !important; color: #ffffff !important; } .vis-cbg label.selected * { color: #ffffff !important; } /* 3D viewport — full-width 16:9 with border */ .plot-viewport { border: 2px solid #ded9f4 !important; border-radius: 8px !important; } .plot-viewport .plot-container { aspect-ratio: 16 / 9 !important; width: 100% !important; } .plot-viewport .js-plotly-plot, .plot-viewport .plotly { width: 100% !important; height: 100% !important; } /* Neighbors dropdown — compact */ .nn-dropdown { max-width: 100px !important; } .nn-dropdown select { color: #63348d !important; } /* Hidden camera state textbox (visible=False prevents DOM rendering in Gradio 6) */ .camera-hidden { display: none !important; } /* Input fields — white for contrast */ textarea, input[type="text"] { background: #ffffff !important; } """ FORCE_LIGHT = """ """ _LIGHT = { "button_primary_background_fill": "#63348d", "button_primary_background_fill_hover": "#4a2769", "button_primary_text_color": "#ffffff", "block_background_fill": "#f3f0f7", "block_border_color": "#ded9f4", "body_background_fill": "#ffffff", "body_text_color": "#1a1a2e", "block_label_text_color": "#63348d", "block_title_text_color": "#63348d", "background_fill_primary": "#ffffff", "background_fill_secondary": "#f3f0f7", "input_background_fill": "#ffffff", "input_background_fill_focus": "#ffffff", "input_border_color": "#ded9f4", "input_border_color_focus": "#63348d", "input_placeholder_color": "#888888", "border_color_primary": "#ded9f4", "border_color_accent": "#63348d", "panel_background_fill": "#f3f0f7", } # Mirror every light value into _dark so dark mode looks identical _ALL = {} for k, v in _LIGHT.items(): _ALL[k] = v _ALL[k + "_dark"] = v THEME = gr.themes.Soft( primary_hue="purple", font=gr.themes.GoogleFont("Inter"), ).set(**_ALL) with gr.Blocks(title="Embedding Explorer") as demo: # ── State ── all_labels_state = gr.State([]) loading_share = gr.State(False) # suppress cascading events during share load camera_txt = gr.Textbox(elem_id="camera_txt", elem_classes=["camera-hidden"]) share_params = gr.State({}) # Force light mode fallback (head param covers most cases, this catches HF Spaces) gr.HTML('') gr.Markdown( "# Embedding Explorer\n" "Words in AI models are stored as **vectors** — long lists of numbers " "that encode meaning. Similar words have similar vectors. " "You can even do **vector math**: *king − man + woman ≈ queen*. " "This tool lets you explore these representations in 3D using " "[GloVe](https://nlp.stanford.edu/projects/glove/) word vectors " f"({len(VOCAB):,} words, {DIMS} dimensions)." ) gr.Markdown( "*Enter words to see them in 3D, or try vector arithmetic " "with `+` and `−`. Separate groups with commas. " "Click an item below the plot to see its nearest neighbors.*" ) with gr.Row(): with gr.Column(scale=2): exp_in = gr.Textbox( label="Words or expressions (comma-separated)", placeholder="dog cat fish or king - man + woman", lines=1, ) with gr.Column(scale=1): with gr.Row(): exp_btn = gr.Button("Explore", variant="primary") share_btn = gr.Button("Share", variant="secondary", scale=0, min_width=80) share_url = gr.Textbox(label="Share URL", visible=False, interactive=False, buttons=["copy"]) with gr.Column(elem_classes=["purple-examples"]): gr.Examples( examples=[[e] for e in EXAMPLES], inputs=[exp_in], label="Try these", ) exp_plot = gr.Plot(label="Embedding Space", elem_classes=["plot-viewport"]) exp_status = gr.Markdown("") vis_cbg = gr.CheckboxGroup( label="Visible items (uncheck to hide)", choices=[], value=[], visible=False, interactive=True, elem_classes=["vis-cbg"], ) with gr.Row(): exp_radio = gr.Radio( label="Click to see nearest neighbors", choices=[], value=None, visible=False, interactive=True, elem_classes=["nbr-radio"], ) nn_dropdown = gr.Dropdown( label="Neighbors", choices=[str(i) for i in range(3, 13)], value=str(N_NEIGHBORS), interactive=True, scale=0, min_width=90, elem_classes=["nn-dropdown"], ) # ── Event handlers ── def _parse_camera_json(camera_json): """Parse camera JSON string (from JS bridge) into Plotly camera dict.""" if not camera_json: return None try: return json.loads(camera_json) except (json.JSONDecodeError, TypeError): return None def _get_nn(nn_val): """Parse neighbor count from dropdown value.""" try: return int(nn_val) except (TypeError, ValueError): return N_NEIGHBORS def on_explore(input_text, nn_val=None): """Fresh explore — compute MDS, show all items, reset checkboxes. Supports @word syntax to auto-select a word for neighbors: dog cat fish @dog → plots all 3, shows dog's neighbors """ nn = _get_nn(nn_val) selected = None if input_text and "@" in input_text: match = re.search(r"@(\S+)", input_text) if match: selected = match.group(1).lower() input_text = re.sub(r"\s*@\S+", "", input_text).strip() fig, status, radio, labels = explore(input_text, selected, n_neighbors=nn) cbg = gr.update(choices=labels, value=labels, visible=bool(labels)) return fig, status, radio, labels, cbg, gr.update(value=input_text) def on_radio(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val): """Neighbor selection — re-render with current visibility + camera.""" if is_loading: return gr.update(), gr.update(), gr.update(), False nn = _get_nn(nn_val) hidden = set(all_labels) - set(visible) if all_labels and visible else set() camera = _parse_camera_json(camera_json) fig, status, radio, _ = explore(input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=nn) return fig, status, radio, False def on_visibility(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val): """Visibility toggle — re-render with updated hidden set + camera.""" if is_loading: return gr.update(), gr.update(), gr.update(), False nn = _get_nn(nn_val) hidden = set(all_labels) - set(visible) if all_labels else set() # If selected item is now hidden, clear selection if selected and selected != "(clear)" and selected in hidden: selected = None camera = _parse_camera_json(camera_json) fig, status, radio, _ = explore(input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=nn) return fig, status, radio, False def on_nn_change(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val): """Neighbor count changed — re-render if a word is selected.""" if is_loading: return gr.update(), gr.update(), gr.update(), False if not selected or selected == "(clear)": return gr.update(), gr.update(), gr.update(), False nn = _get_nn(nn_val) hidden = set(all_labels) - set(visible) if all_labels and visible else set() camera = _parse_camera_json(camera_json) fig, status, radio, _ = explore(input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=nn) return fig, status, radio, False def on_share(input_text, selected, visible, camera_json, nn_val, request: gr.Request): """Build share URL encoding current state.""" params = {} if input_text and input_text.strip(): params["q"] = input_text.strip() if selected and selected != "(clear)": params["sel"] = selected # Only encode visibility if some items are hidden if visible is not None and isinstance(visible, list): params["vis"] = ",".join(visible) if camera_json: encoded = _encode_camera(camera_json) if encoded: params["cam"] = encoded nn = _get_nn(nn_val) if nn != N_NEIGHBORS: params["nn"] = str(nn) if not params.get("q"): return gr.update(value="Nothing to share", visible=True) # Build base URL from request (gets correct port for local dev) base_url = _BASE_URL if request: host = request.headers.get("host", "") if host: scheme = "https" if _SPACE_ID else "http" base_url = f"{scheme}://{host}/" long_url = base_url + "?" + urllib.parse.urlencode(params) # On localhost, just return the full URL (Rebrandly rejects non-public URLs) if "localhost" in long_url or "127.0.0.1" in long_url: return gr.update(value=long_url, visible=True) short = _shorten_url(long_url) return gr.update(value=short, visible=True) # ── Wire up events ── _EXAMPLE_SET = set(EXAMPLES) def on_input_change(input_text, nn_val): """Auto-explore when input matches an example (set by gr.Examples click).""" if input_text and input_text.strip() in _EXAMPLE_SET: return on_explore(input_text, nn_val) return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() exp_in.change( on_input_change, inputs=[exp_in, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in], ) exp_btn.click( on_explore, inputs=[exp_in, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in], ) exp_in.submit( on_explore, inputs=[exp_in, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in], ) # Radio + visibility + nn: camera_txt is kept up-to-date by polling script exp_radio.change( on_radio, inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, loading_share], ) vis_cbg.change( on_visibility, inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, loading_share], ) nn_dropdown.change( on_nn_change, inputs=[exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown], outputs=[exp_plot, exp_status, exp_radio, loading_share], ) # Share: camera_txt kept up-to-date by polling script share_btn.click( fn=on_share, inputs=[exp_in, exp_radio, vis_cbg, camera_txt, nn_dropdown], outputs=[share_url], ) # ── Share URL loading ── def load_share_params(request: gr.Request): """Step 1: Parse query params from URL.""" qp = dict(request.query_params) if request else {} return qp def apply_share_params(params): """Step 2: Apply share params — set input, run explore, apply visibility + camera + nn.""" if not params or "q" not in params: # Check if nn param is present even without q nn_str = params.get("nn") if params else None nn_update = gr.update(value=nn_str) if nn_str else gr.update() return ( gr.update(), # exp_in gr.update(), # exp_plot gr.update(), # exp_status gr.update(), # exp_radio gr.update(), # vis_cbg [], # all_labels_state gr.update(), # camera_txt False, # loading_share nn_update, # nn_dropdown ) input_text = params.get("q", "") selected = params.get("sel") if selected == "": selected = None vis_str = params.get("vis") cam_str = params.get("cam") nn_str = params.get("nn") camera = _parse_camera(cam_str) nn = int(nn_str) if nn_str and nn_str.isdigit() else None # First explore with all items visible to get labels _, _, _, labels = explore(input_text, None, camera=camera, n_neighbors=nn) # Apply visibility if vis_str: visible = [v.strip() for v in vis_str.split(",")] hidden = set(labels) - set(visible) else: visible = labels hidden = set() fig, status, radio, _ = explore( input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=nn ) cbg = gr.update( choices=labels, value=visible, visible=bool(labels), ) # Pre-populate camera_txt so subsequent re-renders preserve camera camera_json = json.dumps(camera) if camera else "" nn_update = gr.update(value=str(nn)) if nn else gr.update() return ( gr.update(value=input_text), fig, status, radio, cbg, labels, gr.update(value=camera_json), True, # loading_share — suppress cascading events nn_update, ) demo.load( fn=load_share_params, outputs=[share_params], ).then( fn=apply_share_params, inputs=[share_params], outputs=[exp_in, exp_plot, exp_status, exp_radio, vis_cbg, all_labels_state, camera_txt, loading_share, nn_dropdown], ) demo.launch(theme=THEME, css=CSS, head=FORCE_LIGHT)