Spaces:
Running
Running
| """ | |
Embedding Explorer — Interactive word vector visualization
Responsible AI: Technology, Power, and Justice
Huston-Tillotson University

Enter words or vector expressions (comma-separated; plain words may also
be separated by spaces). Each item becomes an arrow in 3D. Click an item
to see its nearest neighbors.

Configuration (HuggingFace Space environment variables):
    EXAMPLES — JSON list of example inputs (words or expressions)
    N_NEIGHBORS — number of neighbors to show on click (default 4)
| """ | |
| import os | |
| import json | |
| import re | |
| import warnings | |
| import urllib.parse | |
| import subprocess | |
| import pandas # noqa: F401 β import before plotly to avoid circular import | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
| warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") | |
# ── Configuration (all changeable via HF Space env vars) ─────
# Built-in example inputs, used when the EXAMPLES env var is not set.
# Each entry is one input string: comma-separated items, where an item is
# either a run of plain words or a +/- vector expression.
_DEFAULT_EXAMPLES = [
    "dog cat fish car truck",
    "paris france berlin germany tokyo japan",
    "man woman king queen prince princess",
    "man - woman, uncle - aunt, man woman uncle aunt",
    "aunt - woman + man, man woman uncle aunt",
    "nephew - man + woman, man woman nephew niece",
    "king - man + woman, man woman king queen",
    "paris - france + italy, paris france italy rome",
    "sushi - japan + germany, sushi japan germany bratwurst",
    "hitler - germany + italy, germany italy hitler mussolini",
]

# EXAMPLES must be a JSON list when supplied through the environment.
_examples_env = os.environ.get("EXAMPLES")
if _examples_env is None:
    EXAMPLES = list(_DEFAULT_EXAMPLES)
else:
    EXAMPLES = json.loads(_examples_env)

# Number of nearest neighbors shown when an item is clicked.
N_NEIGHBORS = int(os.getenv("N_NEIGHBORS", "4"))
# ── Share URL infrastructure ─────────────────────────────────
# API key for the Rebrandly link shortener; empty disables shortening.
REBRANDLY_API_KEY = os.environ.get("REBRANDLY_API_KEY", "")

# Base URL that share links are built on. Resolution order:
#   1. SPACE_HOST, the host name HuggingFace injects for the running Space
#      (already normalized, so no guessing about case or punctuation);
#   2. a host derived from SPACE_ID ("owner/name");
#   3. the local Gradio default.
# A malformed SPACE_ID (no slash, extra slashes, empty half) falls through
# to the next option instead of raising ValueError at import time.
_SPACE_ID = os.environ.get("SPACE_ID", "")
_SPACE_HOST = os.environ.get("SPACE_HOST", "").strip().strip("/")

_BASE_URL = "http://localhost:7860/"
_space_parts = _SPACE_ID.split("/")
if len(_space_parts) == 2 and all(_space_parts):
    _owner, _name = _space_parts
    # NOTE(review): HF appears to lowercase/normalize Space subdomains, so
    # this derived host may be wrong for names with "_" or capitals —
    # confirm; SPACE_HOST below takes precedence whenever it is set.
    _BASE_URL = f"https://{_owner}-{_name}.hf.space/"
if _SPACE_HOST:
    _BASE_URL = f"https://{_SPACE_HOST}/"
# ── Course design system colors ──────────────────────────────
PURPLE = "#63348d"
PURPLE_LIGHT = "#ded9f4"
PURPLE_DARK = "#301848"
GOLD = "#f0c040"
PINK = "#de95a0"
DARK = "#1a1a2e"
GRAY = "#888888"
BG = "#fafafa"

# Categorical palette — darkest shade from each design-system color family
# 10 shades per family from color_palette.md (darkest → lightest)
_SHADES = {
    "purple": ["#64348d", "#6f3ba4", "#7942bb", "#8455c5", "#906acd",
               "#9d80d6", "#ab95de", "#baabe5", "#cbc1ec", "#ddd7f4"],
    "blue":   ["#344b8d", "#3b58a4", "#4266bb", "#5579c5", "#6a8dcd",
               "#809fd6", "#95b2de", "#abc3e5", "#c1d4ec", "#d7e5f4"],
    "green":  ["#348d64", "#3ba470", "#42bb7c", "#55c588", "#6acd95",
               "#80d6a2", "#95deb0", "#abe5bf", "#c1eccf", "#d7f4e0"],
    "red":    ["#8d3437", "#a43b41", "#bb424a", "#c5555f", "#cd6a75",
               "#d6808a", "#de95a0", "#e5abb4", "#ecc1c9", "#f4d7dd"],
    "yellow": ["#8d7734", "#a48b3b", "#bb9f42", "#c5ac55", "#cdb86a",
               "#d6c480", "#decf95", "#e5daab", "#ece5c1", "#f4efd7"],
    "orange": ["#8d5534", "#a4633b", "#bb7242", "#c58355", "#cd946a",
               "#d6a580", "#deb695", "#e5c6ab", "#ecd6c1", "#f4e5d7"],
}

# Darkest shade from each family — up to 12 items cycle through these 6
PALETTE = [
    _SHADES["purple"][0],
    _SHADES["blue"][0],
    _SHADES["red"][0],
    _SHADES["yellow"][0],
    _SHADES["green"][0],
    _SHADES["orange"][0],
]

# Map each darkest color to its shade family for lookup
_COLOR_FAMILY = {shades[0]: shades for shades in _SHADES.values()}


def lighten(hex_color, amount=0.3):
    """Return a lighter variant of ``hex_color``.

    For a PALETTE color (the darkest shade of a family, matched
    case-insensitively) the result is picked from that family's shade
    list: ``amount`` in 0.0–1.0 maps onto the shade index. Any other
    color (e.g. GOLD) is blended arithmetically toward white.

    Args:
        hex_color: Color as "#rrggbb".
        amount: Lightening strength. Values outside 0.0–1.0 are clamped.

    Returns:
        A lowercase "#rrggbb" string.
    """
    # Clamp first: a negative amount would otherwise index the shade list
    # from the end (picking a *light* shade), and an amount above 1 would
    # push channels past 0xff and produce a malformed hex string.
    amount = min(max(amount, 0.0), 1.0)

    shades = _COLOR_FAMILY.get(hex_color.lower())
    if shades is not None:
        idx = min(int(amount * len(shades)), len(shades) - 1)
        return shades[idx]

    # Fallback for non-palette colors (e.g., GOLD): blend toward white.
    h = hex_color.lstrip("#")
    channels = (int(h[i:i + 2], 16) for i in (0, 2, 4))
    blended = (int(c + (255 - c) * amount) for c in channels)
    return "#" + "".join(f"{c:02x}" for c in blended)
# ── Load GloVe embeddings on startup ─────────────────────────
| import time | |
| import gensim.downloader as api | |
| from gensim.models import KeyedVectors | |
# Native binary cache — loads ~10x faster than gensim's text format
_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".cache")
_CACHE_PATH = os.path.join(_CACHE_DIR, "glove-wiki-gigaword-300.kv")


def load_model(name="glove-wiki-gigaword-300", retries=5):
    """Load word vectors, preferring the on-disk native cache.

    Fast path: if ``<_CACHE_DIR>/<name>.kv`` exists it is loaded
    memory-mapped. Slow path: the model is downloaded through
    ``gensim.downloader`` (with exponential back-off between attempts) and
    then written to the cache for the next start.

    Args:
        name: gensim-data model name. The cache file is keyed on this name,
            so different models never share a cache file. For the default
            name the path equals ``_CACHE_PATH``.
        retries: Maximum number of download attempts; at least one attempt
            is always made.

    Returns:
        The loaded vectors (whatever ``KeyedVectors.load`` / ``api.load``
        return).

    Raises:
        Exception: The error of the final download attempt, re-raised
            unchanged once all attempts are used up.
    """
    banner = "=" * 60
    cache_path = os.path.join(_CACHE_DIR, f"{name}.kv")

    # Fast path: load from native binary cache (memory-mapped)
    if os.path.exists(cache_path):
        print(banner)
        print("Loading GloVe vectors from cache (memory-mapped)...")
        print(banner)
        t0 = time.time()
        try:
            m = KeyedVectors.load(cache_path, mmap="r")
        except Exception as e:  # noqa: BLE001 — any unreadable cache means "re-download"
            # A truncated or corrupt cache (e.g. an interrupted save) must
            # not block startup forever; fall through to the download.
            print(f"Cache load failed ({e}); falling back to download.")
        else:
            print(f"Loaded in {time.time() - t0:.1f}s")
            return m

    # Slow path: download via gensim, then save native cache
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        print(banner)
        print(f"Downloading GloVe word vectors ({name})...")
        if attempt == 1:
            print("First run only \u2014 ~376 MB download. Will cache for fast startup.")
        else:
            print(f"Retry {attempt}/{attempts}...")
        print(banner)
        try:
            m = api.load(name)
        except Exception as e:  # noqa: BLE001 — retry boundary; last failure is re-raised
            print(f"Attempt {attempt} failed: {e}")
            if attempt >= attempts:
                raise
            wait = 2 ** attempt
            print(f"Retrying in {wait}s...")
            time.sleep(wait)
            continue

        # Writing the cache is best-effort: a read-only or full disk should
        # cost startup time next run, not discard a successful download.
        try:
            os.makedirs(_CACHE_DIR, exist_ok=True)
            m.save(cache_path)
        except OSError as e:
            print(f"Could not write cache to {_CACHE_DIR} ({e}); continuing without it.")
        else:
            print(f"Cached to {_CACHE_DIR} for fast startup next time.")
        return m
# Loaded once at import time; the helpers below read these module globals.
model = load_model()
DIMS = model.vector_size
# Set of every known word, used for membership tests on user input.
VOCAB = set(model.key_to_index)
print(f"Ready: {len(VOCAB):,} words, {DIMS} dimensions each")
# ── Helpers ──────────────────────────────────────────────────
def parse_expression(expr):
    """Parse an arithmetic word expression such as 'king - man + woman'.

    Tokens are numbers (coefficients), words, and the operators + / -.
    A coefficient applies to the single word that follows it; a sign
    applies to every word up to the next operator. Words that are not in
    VOCAB are skipped. Fractional coefficients are supported:
    '0.5 king - 0.3 man + 1.5 woman'.

    Args:
        expr: Expression text; matched case-insensitively.

    Returns:
        (pos, neg, ordered): ``pos`` and ``neg`` are lists of
        (word, coeff) for added / subtracted operands; ``ordered`` is
        [(word, sign, coeff), ...] in input order, for display formatting.
    """
    tokens = re.findall(r"\d*\.?\d+|[a-z']+|[+-]", expr.lower())
    pos, neg, ordered = [], [], []
    sign = "+"
    coeff = 1.0
    for t in tokens:
        if t in ("+", "-"):
            sign = t
            coeff = 1.0
        elif re.fullmatch(r"\d*\.?\d+", t):
            coeff = float(t)
        else:
            if t in VOCAB:
                (pos if sign == "+" else neg).append((t, coeff))
                ordered.append((t, sign, coeff))
            # A coefficient belongs to the word right after it. Reset it
            # for unknown words too, so '2 xyzzy king' does not silently
            # become 2·king.
            coeff = 1.0
    return pos, neg, ordered
def _coeff_str(c):
    """Render a coefficient as a label prefix.

    A coefficient of 1.0 is implicit and yields ''. Anything else yields
    the number followed by a middle dot: whole numbers without a decimal
    point ('2·'), other values in general format ('0.5·').
    """
    if c == 1.0:
        return ""
    shown = int(c) if c == int(c) else f"{c:g}"
    return f"{shown}\u00b7"
def parse_items(text):
    """Parse comma-separated input into items (words or vector expressions).

    Each comma-separated part is either an arithmetic expression (it has a
    + or - between word/digit characters) or a run of whitespace-separated
    plain words, each of which becomes its own item. Items with a label
    that was already produced are dropped.

    Args:
        text: Raw user input; may be empty or None.

    Returns:
        (items, bad_words). Each item is the tuple
        (label, vector, is_expr, operand_words, ordered_ops_or_None).
        ``bad_words`` lists every word that is not in VOCAB, including
        unknown operands of expressions that were still evaluated from
        their remaining operands.
    """
    if not text or not text.strip():
        return [], []

    items = []
    bad_words = []
    seen_labels = set()

    for part in (p.strip() for p in text.split(",")):
        if not part:
            continue
        lowered = part.lower()

        # Detect expression: contains + or - between word characters or digits
        if re.search(r"(?:[a-z']|\d)\s*[+\-]\s*(?:[a-z']|\d)", lowered):
            pos, neg, ordered = parse_expression(part)

            # Always report unknown operands. Otherwise 'king - xyzzy + woman'
            # is quietly evaluated as 'king + woman' with no hint as to why.
            bad_words.extend(
                t for t in re.findall(r"[a-z']+", lowered) if t not in VOCAB
            )
            if len(pos) + len(neg) < 2:
                continue  # not enough known operands to form an expression

            # Compute result vector
            vec = np.zeros(DIMS)
            for w, c in pos:
                vec += c * model[w]
            for w, c in neg:
                vec -= c * model[w]

            # Build label. The first term keeps its minus sign too: after a
            # skipped unknown word the first *known* operand can be
            # negative, and the label must match the vector computed above.
            terms = []
            for w, s, c in ordered:
                term = f"{_coeff_str(c)}{w}"
                if not terms:
                    terms.append(f"\u2212{term}" if s == "-" else term)
                elif s == "+":
                    terms.append(f"+ {term}")
                else:
                    terms.append(f"\u2212 {term}")
            label = " ".join(terms)

            if label not in seen_labels:
                seen_labels.add(label)
                operand_words = {w for w, _ in pos + neg}
                items.append((label, vec, True, operand_words, list(ordered)))
        else:
            # Plain words — each word is a separate item
            for w in lowered.split():
                if w not in VOCAB:
                    bad_words.append(w)
                    continue
                if w not in seen_labels:
                    seen_labels.add(w)
                    items.append((w, model[w], False, {w}, None))

    return items, bad_words
def reduce_3d(vectors):
    """Project vectors to 3D with MDS on pairwise cosine distances.

    Fewer than two vectors cannot be embedded and yield all-zero
    coordinates. With exactly two vectors only two MDS components are
    fitted and the third coordinate is zero-padded. The result is scaled
    so the largest absolute coordinate is 1, which keeps axis ranges and
    label sizes consistent between inputs.

    Returns:
        Array of shape (len(vectors), 3).
    """
    n_points = len(vectors)
    if n_points < 2:
        return np.zeros((n_points, 3))

    # Imported lazily: scikit-learn is only needed once there is something
    # to embed.
    from sklearn.manifold import MDS
    from sklearn.metrics.pairwise import cosine_distances

    n_dims = min(3, n_points)
    embedder = MDS(
        n_components=n_dims,
        dissimilarity="precomputed",
        random_state=42,  # fixed seed → repeatable layout for the same input
        normalized_stress="auto",
        max_iter=300,
    )
    embedding = embedder.fit_transform(cosine_distances(vectors))

    if n_dims < 3:
        padding = np.zeros((n_points, 3 - n_dims))
        embedding = np.hstack([embedding, padding])

    # Normalize to [-1, 1] so axis ranges and label sizes are consistent
    scale = np.abs(embedding).max()
    return embedding / scale if scale > 1e-8 else embedding
def _axis(title=""):
    """Build the settings for one bare 3D axis.

    Grid, zero line, tick labels, hover spikes and background plane are all
    switched off; the scene draws its own floor grid instead.
    """
    return {
        "showgrid": False,
        "zeroline": False,
        "showticklabels": False,
        "title": title,
        "showspikes": False,
        "showbackground": False,
    }
def layout_3d(axis_range=1.3, camera=None):
    """Build the Plotly layout shared by every 3D figure.

    Args:
        axis_range: Half-width of the cube; each axis spans
            [-axis_range, axis_range].
        camera: Plotly camera dict; a default eye position is used when
            this is missing or empty.

    Returns:
        A dict of layout keyword arguments. ``uirevision`` is constant so
        user camera interaction is preserved across figure updates.
    """
    span = [-axis_range, axis_range]
    axes = {}
    for key in ("xaxis", "yaxis", "zaxis"):
        settings = _axis()
        settings["range"] = span
        axes[key] = settings

    if not camera:
        camera = {"eye": {"x": 1.0, "y": 1.0, "z": 0.8}}

    ui_font = "Inter, sans-serif"
    scene = {
        **axes,
        "bgcolor": "white",
        "camera": camera,
        "aspectmode": "cube",
    }
    legend = {
        "yanchor": "top", "y": 0.99,
        "xanchor": "right", "x": 0.99,
        "bgcolor": "rgba(255,255,255,0.85)",
        "font": {"family": ui_font, "size": 12},
    }
    return {
        "scene": scene,
        "paper_bgcolor": "white",
        "margin": {"l": 0, "r": 0, "t": 10, "b": 10},
        "showlegend": True,
        "legend": legend,
        "font": {"family": ui_font},
        "autosize": True,
        "uirevision": "keep",
    }
ARROW_WIDTH = 10  # line width in pixels
ARROW_HEAD_LENGTH = 0.08  # arrowhead length in coordinate units
ARROW_HEAD_WIDTH = 0.03  # arrowhead half-width in coordinate units


def add_arrow(fig, px, py, pz, color, width=ARROW_WIDTH,
              head_length=ARROW_HEAD_LENGTH, head_width=ARROW_HEAD_WIDTH,
              sx=0, sy=0, sz=0, dash=None):
    """Draw a vector arrow from (sx, sy, sz) to (px, py, pz).

    Adds two traces to ``fig``: a Scatter3d line for the shaft and a Mesh3d
    four-sided pyramid for the head. Nothing is drawn for a (near)
    zero-length arrow.

    Args:
        fig: Plotly figure to add the traces to.
        px, py, pz: Tip coordinates.
        color: Color of shaft and head.
        width: Shaft width in pixels.
        head_length: Maximum head length in coordinate units. It is capped
            at 30% of the arrow length, and the head width shrinks in
            proportion, so short arrows get a small head rather than one
            that reaches back past the start point.
        head_width: Head half-width (at full head length).
        sx, sy, sz: Start coordinates (default: origin).
        dash: Plotly dash style for the shaft, or None for solid.
    """
    start = np.array([sx, sy, sz], dtype=float)
    tip = np.array([px, py, pz], dtype=float)
    vec = tip - start
    length = np.linalg.norm(vec)
    if length < 1e-8:
        return
    d = vec / length  # unit direction

    # One effective head size drives both the shaft end and the head base,
    # so the shaft always stops exactly where the head begins.
    head_len = min(head_length, length * 0.3)
    scale = head_len / head_length if head_length > 0 else 1.0
    head_half_w = head_width * scale
    base = tip - d * head_len

    # Vector line (stops at the base of the head)
    fig.add_trace(go.Scatter3d(
        x=[start[0], base[0]], y=[start[1], base[1]], z=[start[2], base[2]],
        mode="lines", line=dict(color=color, width=width, dash=dash),
        showlegend=False, hoverinfo="none",
    ))

    # Two unit vectors perpendicular to the arrow span the head's base.
    # The reference axis switches from Z to Y for near-vertical arrows so
    # the cross product never degenerates.
    up = np.array([0, 0, 1]) if abs(d[2]) < 0.9 else np.array([0, 1, 0])
    p1 = np.cross(d, up)
    p1 = p1 / np.linalg.norm(p1)
    p2 = np.cross(d, p1)

    w1 = base + p1 * head_half_w
    w2 = base - p1 * head_half_w
    w3 = base + p2 * head_half_w
    w4 = base - p2 * head_half_w

    # 5 vertices: tip + 4 base points
    vx = [tip[0], w1[0], w2[0], w3[0], w4[0]]
    vy = [tip[1], w1[1], w2[1], w3[1], w4[1]]
    vz = [tip[2], w1[2], w2[2], w3[2], w4[2]]

    # 4 triangular faces forming the pyramid; uniform ambient lighting keeps
    # the head the same flat color as the shaft from every viewing angle.
    fig.add_trace(go.Mesh3d(
        x=vx, y=vy, z=vz,
        i=[0, 0, 0, 0], j=[1, 3, 2, 4], k=[3, 2, 4, 1],
        color=color, opacity=1.0,
        flatshading=True,
        lighting=dict(ambient=1, diffuse=0, specular=0, fresnel=0),
        showlegend=False, hoverinfo="none",
    ))
def add_floor_grid(fig, range_val=1.3, step=0.25):
    """Draw a grid on the z=0 plane and axis lines through origin (3B1B style).

    Grid lines sit at multiples of ``step`` centered on the origin, so the
    grid is symmetric and one line in each direction coincides with the
    X / Y axis. Every line spans the full [-range_val, range_val] extent.

    Args:
        fig: Plotly figure to add the traces to.
        range_val: Half-width of the drawn area.
        step: Spacing between neighboring grid lines.
    """
    grid_color = "rgba(99,52,141,0.18)"
    axis_color = "rgba(99,52,141,0.60)"
    lo, hi = -range_val, range_val

    # Counting steps outward from 0 (rather than starting at -range_val)
    # keeps the grid aligned with the axes when range_val is not a multiple
    # of step — e.g. the default 1.3 / 0.25. The epsilon absorbs float
    # error in exact ratios such as 1.5 / 0.25.
    n_steps = int(np.floor(range_val / step + 1e-9))
    ticks = [k * step for k in range(-n_steps, n_steps + 1)]

    # Each grid direction is one trace; None separators break the polyline
    # between segments, which is far cheaper than a trace per line.
    fixed, spanning, flat = [], [], []
    for t in ticks:
        fixed.extend([t, t, None])
        spanning.extend([lo, hi, None])
        flat.extend([0, 0, None])

    for xs, ys in (
        (fixed, spanning),  # lines parallel to the Y axis (constant x)
        (spanning, fixed),  # lines parallel to the X axis (constant y)
    ):
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=flat,
            mode="lines", line=dict(color=grid_color, width=1),
            showlegend=False, hoverinfo="none",
        ))

    # Three axis lines through origin
    for ax_x, ax_y, ax_z in (
        ([lo, hi], [0, 0], [0, 0]),  # X
        ([0, 0], [lo, hi], [0, 0]),  # Y
        ([0, 0], [0, 0], [lo, hi]),  # Z
    ):
        fig.add_trace(go.Scatter3d(
            x=ax_x, y=ax_y, z=ax_z,
            mode="lines", line=dict(color=axis_color, width=2),
            showlegend=False, hoverinfo="none",
        ))
def blank(msg):
    """Return an empty white figure showing ``msg`` centered in gray.

    Used as the placeholder whenever there is nothing to plot.
    """
    placeholder = go.Figure()
    message_font = dict(size=16, color=GRAY, family="Inter, sans-serif")
    placeholder.add_annotation(
        text=msg,
        showarrow=False,
        font=message_font,
        # Paper coordinates: (0.5, 0.5) is the middle of the figure.
        xref="paper", yref="paper", x=0.5, y=0.5,
    )
    placeholder.update_layout(
        height=560,
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="white", plot_bgcolor="white",
        xaxis_visible=False, yaxis_visible=False,
    )
    return placeholder
# ── Share URL helpers ────────────────────────────────────────
def _shorten_url(long_url):
    """Shorten a URL via the Rebrandly API (called through ``curl``).

    Strictly best-effort: with no API key, for localhost URLs, and on any
    failure (curl missing or failing, timeout, empty / malformed /
    unexpected response) the original ``long_url`` is returned.

    Args:
        long_url: The full share URL.

    Returns:
        "https://<short link>" on success, otherwise ``long_url``.
    """
    if not REBRANDLY_API_KEY or "localhost" in long_url:
        return long_url

    payload = json.dumps({
        "destination": long_url,
        "domain": {"fullName": "go.ropavieja.org"},
    })
    command = [
        "curl", "-s", "-X", "POST",
        "https://api.rebrandly.com/v1/links",
        "-H", "Content-Type: application/json",
        "-H", f"apikey: {REBRANDLY_API_KEY}",
        "-d", payload,
    ]
    try:
        result = subprocess.run(
            command, capture_output=True, text=True, timeout=10,
        )
        if result.returncode != 0 or not result.stdout.strip():
            print(f"[share] Rebrandly error: curl exited with {result.returncode}")
            return long_url
        # TypeError covers a JSON body that is not an object (list, null,
        # string), which would otherwise escape and break sharing.
        short = json.loads(result.stdout)["shortUrl"]
    except (subprocess.TimeoutExpired, KeyError, TypeError,
            json.JSONDecodeError, OSError) as exc:
        print(f"[share] Rebrandly error: {exc}")
        return long_url

    if not isinstance(short, str) or not short:
        print("[share] Rebrandly error: response has no usable shortUrl")
        return long_url
    return f"https://{short}"
def _parse_camera(cam_str):
    """Parse a compact camera string into a Plotly camera dict.

    Format: ``ex,ey,ez[,cx,cy,cz[,ux,uy,uz]]`` — eye, then optionally
    center, then optionally up. A trailing group with fewer than three
    values is ignored.

    The string comes from a share URL, so it is treated as untrusted:
    anything that is not a comma-separated list of at least three finite
    numbers yields None rather than a camera Plotly cannot render.

    Args:
        cam_str: The encoded camera, or None / "".

    Returns:
        A dict with "eye" and, when present, "center" and "up"; or None.
    """
    if not cam_str:
        return None
    try:
        vals = [float(v) for v in cam_str.split(",")]
    except (ValueError, AttributeError):
        # ValueError: a field is not a number; AttributeError: not a string.
        return None

    if len(vals) < 3:
        return None
    # float() also accepts "nan" and "inf"; `abs(v) < inf` is False for both.
    if not all(abs(v) < float("inf") for v in vals):
        return None

    camera = dict(eye=dict(x=vals[0], y=vals[1], z=vals[2]))
    if len(vals) >= 6:
        camera["center"] = dict(x=vals[3], y=vals[4], z=vals[5])
    if len(vals) >= 9:
        camera["up"] = dict(x=vals[6], y=vals[7], z=vals[8])
    return camera
def _encode_camera(camera_json):
    """Encode a Plotly camera (JSON text) as a compact URL parameter.

    Output is nine comma-separated values with two decimals:
    eye x,y,z, center x,y,z, up x,y,z. Missing members fall back to eye
    (1.5, 1.5, 1.2), center (0, 0, 0) and up (0, 0, 1).

    The JSON is read back from a browser-side textbox, so it is treated as
    untrusted: anything that is not a JSON object with numeric members
    encodes as "" instead of raising.

    Args:
        camera_json: JSON text of the camera, or None / "".

    Returns:
        The encoded string, or "" when there is nothing valid to encode.
    """
    if not camera_json:
        return ""
    try:
        cam = json.loads(camera_json)
        eye = cam.get("eye", {})
        center = cam.get("center", {})
        up = cam.get("up", {})
        # NOTE(review): this fallback eye differs from the default camera
        # in layout_3d (1.0, 1.0, 0.8) — confirm which one is intended.
        vals = [
            eye.get("x", 1.5), eye.get("y", 1.5), eye.get("z", 1.2),
            center.get("x", 0), center.get("y", 0), center.get("z", 0),
            up.get("x", 0), up.get("y", 0), up.get("z", 1),
        ]
        return ",".join(f"{v:.2f}" for v in vals)
    except (ValueError, TypeError, AttributeError):
        # ValueError: invalid JSON (JSONDecodeError) or a non-numeric string
        # member; TypeError: a null member or non-text input;
        # AttributeError: the document or a member is not an object.
        return ""
# ── Main visualization ───────────────────────────────────────
def _empty_explore_result(msg):
    """Return explore()'s 4-tuple for the 'nothing to plot' case."""
    return (
        blank(msg),
        "",
        gr.update(choices=[], value=None, visible=False),
        [],
    )


def _label_beside(start, end, offset=0.12):
    """Return a label anchor beside the midpoint of the arrow start→end.

    The anchor is shifted ``offset`` units perpendicular to the arrow, on
    whichever side lies farther from the origin (pushing the label
    outward). Degenerate arrows return the plain midpoint.
    """
    mid = (start + end) / 2
    d = end - start
    length = np.linalg.norm(d)
    if length < 1e-8:
        return mid
    d = d / length
    up = np.array([0., 0., 1.]) if abs(d[2]) < 0.9 else np.array([0., 1., 0.])
    perp = np.cross(d, up)
    pn = np.linalg.norm(perp)
    if pn < 1e-8:
        return mid
    perp = perp / pn * offset
    # Pick side farther from origin (pushes label outward)
    if np.dot(mid + perp, mid + perp) >= np.dot(mid - perp, mid - perp):
        return mid + perp
    return mid - perp


def explore(input_text, selected, hidden=None, camera=None, n_neighbors=None):
    """Unified 3D visualization of words and vector expressions.

    Plain words are drawn as arrows from the origin. An expression draws
    its operand arrows, a dotted tip-to-tail chain through the operands,
    and a gold arrow from the origin to the result. When an item is
    selected, its nearest neighbors are added as dotted arrows and every
    other item is dimmed.

    Args:
        input_text: Comma-separated words and/or expressions.
        selected: Currently selected item for neighbor display (or None).
        hidden: Labels to hide from rendering — any iterable, or None
            (MDS still uses all items).
        camera: Plotly camera dict to set initial view.
        n_neighbors: Number of nearest neighbors to show (default
            N_NEIGHBORS); an int or a numeric string.

    Returns:
        (fig, status_md, radio_update, all_labels)
    """
    if not input_text or not input_text.strip():
        return _empty_explore_result(
            "Enter words or expressions above to visualize in 3D")

    items, bad = parse_items(input_text)
    # Report each unknown word once, in first-seen order.
    bad = list(dict.fromkeys(bad))
    if not items:
        msg = "No valid items found."
        if bad:
            msg += f"<br>Not in vocabulary: {', '.join(bad)}"
        return _empty_explore_result(msg)

    items = items[:12]
    labels = [item[0] for item in items]
    # The UI may pass a list; the set intersection below needs a real set.
    hidden = set(hidden) if hidden else set()

    # Visible labels for radio choices (exclude hidden)
    visible_labels = [lbl for lbl in labels if lbl not in hidden]

    # No auto-select — user clicks radio to see neighbors
    if selected and selected != "(clear)" and selected in visible_labels:
        sel_idx = labels.index(selected)
    else:
        selected = None
        sel_idx = None

    # ── Collect all unique operand words for MDS ──
    # Expression operands are taken in written order, not set order: set
    # iteration over strings changes between processes, and the MDS layout
    # depends on point order, so a shared link would otherwise render
    # differently after a restart.
    all_words = []
    word_set = set()
    for _, _, _, ops, ordered in items:
        operand_seq = [w for w, _, _ in ordered] if ordered else sorted(ops)
        for w in operand_seq:
            if w not in word_set:
                word_set.add(w)
                all_words.append(w)

    # Find nearest high-D word for each expression (used in status + MDS padding)
    expr_nearest = {}  # label -> (word, similarity)
    for label, vec, is_expr, ops, _ in items:
        if not is_expr:
            continue
        # Ask for a few extra results so one survives dropping the operands.
        for w, s in model.similar_by_vector(vec, topn=len(ops) + 5):
            if w not in ops:
                expr_nearest[label] = (w, s)
                break

    # Pad with neighbor words if < 3 unique words (breaks MDS collinearity)
    helper_words = set()
    if len(all_words) < 3:
        # Try expression nearest-words first
        for nw, _ in expr_nearest.values():
            if nw not in word_set:
                word_set.add(nw)
                all_words.append(nw)
                helper_words.add(nw)
            if len(all_words) >= 3:
                break
        # Then try neighbors of any plain word
        if len(all_words) < 3:
            for w in list(all_words):
                if w in helper_words:
                    continue
                for nw, _ in model.most_similar(w, topn=5):
                    if nw not in word_set:
                        word_set.add(nw)
                        all_words.append(nw)
                        helper_words.add(nw)
                    if len(all_words) >= 3:
                        break
                if len(all_words) >= 3:
                    break

    # Gather neighbors if something is selected (and not hidden)
    # int() accepts a numeric string from a dropdown; a negative count
    # would otherwise slice neighbors off the *end* of the list.
    nn = max(0, int(n_neighbors)) if n_neighbors is not None else N_NEIGHBORS
    nbr_data = []
    if selected is not None:
        sel_item = items[sel_idx]
        # Over-fetch so enough remain after filtering out displayed words.
        if sel_item[2]:  # expression
            raw = model.similar_by_vector(sel_item[1], topn=nn + 20)
        else:
            raw = model.most_similar(sel_item[0], topn=nn + 20)
        all_op_words = set()
        for _, _, _, ops, _ in items:
            all_op_words.update(ops)
        label_set = set(labels)
        nbr_data = [(w, s) for w, s in raw
                    if w not in all_op_words and w not in label_set][:nn]

    # ── MDS on all operand words + neighbors ──
    mds_words = all_words + [w for w, _ in nbr_data]
    if not mds_words:
        return _empty_explore_result("No valid words found.")
    mds_vecs = np.array([model[w] for w in mds_words])
    mds_coords = reduce_3d(mds_vecs)

    word_3d = {w: mds_coords[i] for i, w in enumerate(all_words)}
    nbr_coords = mds_coords[len(all_words):] if nbr_data else None

    # ── Compute expression results in 3D ──
    # Every expression is a tip-to-tail chain through its operands in
    # written order; a plain two-term sum is just a chain of length two.
    extra_points = []  # for dynamic axis range
    expr_info = {}     # label -> (chain steps, result point)
    for label, _, is_expr, _, ordered in items:
        if not is_expr:
            continue
        cursor = np.zeros(3)
        chain = []
        for w, s, coeff in ordered:
            prev = cursor.copy()
            step = coeff * word_3d[w]
            cursor = cursor + step if s == "+" else cursor - step
            chain.append((prev, cursor.copy(), w, s))
        # Every chain point counts toward the axis range, not only the
        # final one, so intermediate chain arrows cannot leave the cube.
        extra_points.extend(end for _, end, _, _ in chain)
        expr_info[label] = (chain, cursor.copy())

    # ── Dynamic axis range (computed from ALL items, not just visible) ──
    all_rendered = [word_3d[w] for w in all_words] + extra_points
    if nbr_coords is not None:
        all_rendered.extend(nbr_coords)
    if all_rendered:
        max_abs = np.abs(np.array(all_rendered)).max()
        axis_range = max(1.3, max_abs * 1.15)
    else:
        axis_range = 1.3

    # ── Colors ──
    item_colors = [PALETTE[i % len(PALETTE)] for i in range(len(items))]

    # ── Build figure ──
    fig = go.Figure()
    add_floor_grid(fig, range_val=axis_range)

    annotations = []

    def add_label(x, y, z, text, size=16, color=DARK, opacity=1.0,
                  outward=0.07):
        """Add a 3D text label, shifted outward from origin to avoid overlapping vectors."""
        pt = np.array([x, y, z])
        norm = np.linalg.norm(pt)
        if norm > 1e-8 and outward > 0:
            pt = pt + (pt / norm) * outward
        annotations.append(dict(
            x=float(pt[0]), y=float(pt[1]), z=float(pt[2]),
            text=text, showarrow=False,
            font=dict(size=size, color=color, family="Inter, sans-serif"),
            opacity=opacity, yshift=12,
        ))

    for idx, (label, _, is_expr, _, ordered) in enumerate(items):
        # Skip hidden items
        if label in hidden:
            continue

        color = item_colors[idx]
        is_sel = sel_idx is not None and idx == sel_idx
        is_dim = sel_idx is not None and idx != sel_idx

        # Selection-aware styling
        arr_color = lighten(color, 0.5) if is_dim else color
        arr_width = 4 if is_dim else (12 if is_sel else ARROW_WIDTH)
        lbl_color = arr_color
        lbl_opacity = 0.7 if is_dim else 1.0
        lbl_size = 18 if is_sel else (15 if is_dim else 16)
        gold = lighten(GOLD, 0.3) if is_dim else GOLD
        gold_width = 6 if is_dim else 12

        if not is_expr:
            # ── Plain word: arrow from origin ──
            c = word_3d[label]
            add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width)
            txt = f"<b>{label}</b>" if is_sel else label
            add_label(c[0], c[1], c[2], txt, size=lbl_size,
                      color=lbl_color, opacity=lbl_opacity)
            continue

        chain_steps, result_3d = expr_info[label]

        # ── Operand arrows from origin (full length — shows where the word IS) ──
        for w, _, _ in ordered:
            c = word_3d[w]
            add_arrow(fig, c[0], c[1], c[2], arr_color, width=arr_width)
            txt = f"<b>{w}</b>" if is_sel else w
            add_label(c[0], c[1], c[2], txt, size=lbl_size,
                      color=lbl_color, opacity=lbl_opacity)

        # ── Construction arrows: dotted chain through the operands ──
        chain_color = lighten(color, 0.3) if is_dim else lighten(color, 0.2)
        for start, end, _, _ in chain_steps[1:]:  # first step overlaps operand arrow
            add_arrow(fig, end[0], end[1], end[2],
                      chain_color, width=arr_width,
                      sx=start[0], sy=start[1], sz=start[2], dash="dot")

        # Gold result from origin, labelled beside its midpoint
        add_arrow(fig, result_3d[0], result_3d[1], result_3d[2],
                  gold, width=gold_width)
        lpt = _label_beside(np.zeros(3), result_3d)
        add_label(lpt[0], lpt[1], lpt[2], f"<i>{label}</i>",
                  size=14 if is_dim else 15,
                  color="#b08820" if is_dim else "#8a6a10",
                  opacity=lbl_opacity, outward=0)

    # (helper_words are invisible — only present for MDS geometry)

    # ── Neighbors ──
    if selected is not None and nbr_data and nbr_coords is not None:
        nbr_color = lighten(item_colors[sel_idx], 0.3)
        for (w, _), pt in zip(nbr_data, nbr_coords):
            add_arrow(fig, pt[0], pt[1], pt[2],
                      nbr_color, width=ARROW_WIDTH,
                      sx=0, sy=0, sz=0, dash="dot")
            add_label(pt[0], pt[1], pt[2], w, size=16, color=DARK)

    fig.update_layout(**layout_3d(axis_range=axis_range, camera=camera),
                      scene_annotations=annotations)

    # ── Status text ──
    n_hidden = len(hidden & set(labels))
    n_words = sum(1 for lbl, _, ie, _, _ in items
                  if not ie and lbl not in hidden)
    n_expr = sum(1 for lbl, _, ie, _, _ in items
                 if ie and lbl not in hidden)

    parts = []
    if n_words:
        parts.append(f"**{n_words} word{'s' if n_words != 1 else ''}**")
    if n_expr:
        parts.append(f"**{n_expr} expression{'s' if n_expr != 1 else ''}**")

    # Segments are joined with " · " (middle dot) at the end.
    segments = [" + ".join(parts) + " in 3D" if parts else "Nothing visible"]
    if n_hidden:
        segments.append(f"{n_hidden} hidden")
    if bad:
        segments.append(f"Not found: {', '.join(bad)}")
    for label, (w, s) in expr_nearest.items():
        if label not in hidden:
            segments.append(f"**{label} \u2248 {w}** ({s:.3f})")
    if nbr_data:
        segments.append(f"{len(nbr_data)} neighbors of **{selected}**")
    status = " \u00b7 ".join(segments)

    choices = ["(clear)"] + visible_labels
    return (
        fig,
        status,
        gr.update(choices=choices, value=selected or "(clear)", visible=True),
        labels,
    )
# ── Gradio UI ────────────────────────────────────────────────
# Custom stylesheet for the Gradio app. The rules target the purple-examples,
# nbr-radio, vis-cbg, plot-viewport, nn-dropdown and camera-hidden classes.
# Em dashes in the CSS comments are restored from mis-encoded text.
CSS = """
.gradio-container { max-width: 100% !important; padding: 0 2rem !important; }
h1 { color: #63348d !important; }
/* Example buttons — dark purple outline */
.purple-examples td {
  border: 2px solid #63348d !important;
  border-radius: 6px !important;
  color: #301848 !important;
  cursor: pointer !important;
}
.purple-examples td:hover {
  background: #ded9f4 !important;
}
/* Radio neighbor selector — purple text, white-on-purple when selected */
.nbr-radio label {
  color: #63348d !important;
  border: 1px solid #63348d !important;
  border-radius: 6px !important;
}
.nbr-radio label.selected {
  background: #63348d !important;
  color: #ffffff !important;
}
.nbr-radio label.selected * {
  color: #ffffff !important;
}
/* Visibility checkboxes — compact */
.vis-cbg label {
  color: #63348d !important;
  border: 1px solid #63348d !important;
  border-radius: 6px !important;
}
.vis-cbg label.selected {
  background: #63348d !important;
  color: #ffffff !important;
}
.vis-cbg label.selected * {
  color: #ffffff !important;
}
/* 3D viewport — full-width 16:9 with border */
.plot-viewport {
  border: 2px solid #ded9f4 !important;
  border-radius: 8px !important;
}
.plot-viewport .plot-container {
  aspect-ratio: 16 / 9 !important;
  width: 100% !important;
}
.plot-viewport .js-plotly-plot, .plot-viewport .plotly {
  width: 100% !important;
  height: 100% !important;
}
/* Neighbors dropdown — compact */
.nn-dropdown {
  max-width: 100px !important;
}
.nn-dropdown select {
  color: #63348d !important;
}
/* Hidden camera state textbox (visible=False prevents DOM rendering in Gradio 6) */
.camera-hidden { display: none !important; }
/* Input fields — white for contrast */
textarea, input[type="text"] {
  background: #ffffff !important;
}
"""
# Markup injected into the page <head> at launch. It contains three scripts:
#   1. redirect to the same URL with ?__theme=light when that flag is missing;
#   2. poll the Plotly 3D camera every 500 ms and mirror it, as JSON, into the
#      hidden #camera_txt textbox so the Python handlers can read it;
#   3. once a plot exists, strip the share parameters from the address bar so
#      a refresh does not re-apply them.
FORCE_LIGHT = """
<script>
if(!location.search.includes("__theme=light")){
  const u=new URL(location);u.searchParams.set("__theme","light");location.replace(u);
}
</script>
<script>
// Camera tracker — polls Plotly camera into hidden textbox for Gradio to read
(function() {
  console.log('[cam] Camera tracker script loaded');
  var attempts = 0;
  var interval = setInterval(function() {
    attempts++;
    var plots = document.querySelectorAll('.js-plotly-plot');
    if (plots.length === 0) {
      if (attempts % 20 === 0) console.log('[cam] waiting for plot...', attempts);
      return;
    }
    var plot = plots[0];
    if (!plot._fullLayout || !plot._fullLayout.scene || !plot._fullLayout.scene._scene) {
      if (attempts % 20 === 0) console.log('[cam] plot found but no scene yet');
      return;
    }
    try {
      var cam = plot._fullLayout.scene._scene.getCamera();
      var el = document.querySelector('#camera_txt textarea, #camera_txt input');
      if (!el) {
        console.log('[cam] cannot find #camera_txt element');
        return;
      }
      var val = JSON.stringify(cam);
      if (el.value !== val) {
        // Take the native value setter from the prototype that matches the
        // element: the textarea setter throws "Illegal invocation" when it is
        // called on an <input>, which would skip the events below.
        var proto = (el instanceof window.HTMLTextAreaElement)
          ? window.HTMLTextAreaElement.prototype
          : window.HTMLInputElement.prototype;
        var nativeValueSetter = Object.getOwnPropertyDescriptor(proto, 'value');
        if (nativeValueSetter && nativeValueSetter.set) {
          nativeValueSetter.set.call(el, val);
        } else {
          el.value = val;
        }
        el.dispatchEvent(new Event('input', {bubbles: true}));
        el.dispatchEvent(new Event('change', {bubbles: true}));
      }
    } catch(e) {
      console.log('[cam] error:', e);
    }
  }, 500);
})();
</script>
<script>
// Clear share params from URL after load (so refresh doesn't re-apply)
if (new URL(location).searchParams.has('q')) {
  var _clearId = setInterval(function() {
    if (document.querySelector('.js-plotly-plot')) {
      var clean = new URL(location.pathname, location.origin);
      clean.searchParams.set('__theme', 'light');
      history.replaceState(null, '', clean.toString());
      clearInterval(_clearId);
    }
  }, 500);
}
</script>
"""
| _LIGHT = { | |
| "button_primary_background_fill": "#63348d", | |
| "button_primary_background_fill_hover": "#4a2769", | |
| "button_primary_text_color": "#ffffff", | |
| "block_background_fill": "#f3f0f7", | |
| "block_border_color": "#ded9f4", | |
| "body_background_fill": "#ffffff", | |
| "body_text_color": "#1a1a2e", | |
| "block_label_text_color": "#63348d", | |
| "block_title_text_color": "#63348d", | |
| "background_fill_primary": "#ffffff", | |
| "background_fill_secondary": "#f3f0f7", | |
| "input_background_fill": "#ffffff", | |
| "input_background_fill_focus": "#ffffff", | |
| "input_border_color": "#ded9f4", | |
| "input_border_color_focus": "#63348d", | |
| "input_placeholder_color": "#888888", | |
| "border_color_primary": "#ded9f4", | |
| "border_color_accent": "#63348d", | |
| "panel_background_fill": "#f3f0f7", | |
| } | |
| # Mirror every light value into _dark so dark mode looks identical | |
| _ALL = {} | |
| for k, v in _LIGHT.items(): | |
| _ALL[k] = v | |
| _ALL[k + "_dark"] = v | |
# Start from Gradio's Soft theme (purple hue, Inter web font) and then apply
# every colour override collected in _ALL.
_soft_base = gr.themes.Soft(primary_hue="purple", font=gr.themes.GoogleFont("Inter"))
THEME = _soft_base.set(**_ALL)
# Page layout: every component the event handlers further down read from or
# write to is created in this section.
| with gr.Blocks(title="Embedding Explorer") as demo: | |
| # ── State ── | |
# all_labels_state: labels returned by the latest explore() call.
| all_labels_state = gr.State([]) | |
| loading_share = gr.State(False) # suppress cascading events during share load | |
# The head script locates this textbox through elem_id "camera_txt" and writes
# the camera JSON into it; the "camera-hidden" CSS class hides it from view.
| camera_txt = gr.Textbox(elem_id="camera_txt", elem_classes=["camera-hidden"]) | |
# Carries the parsed query parameters from load_share_params to apply_share_params.
| share_params = gr.State({}) | |
| # Force light mode fallback (head param covers most cases, this catches HF Spaces) | |
# NOTE(review): browsers may not execute a <script> inserted through gr.HTML;
# confirm this fallback actually runs.
| gr.HTML('<script>if(!location.search.includes("__theme=light"))' | |
| '{const u=new URL(location);u.searchParams.set("__theme","light");' | |
| 'location.replace(u)}</script>') | |
# Intro text; the vocabulary size and dimensionality are interpolated from VOCAB and DIMS.
| gr.Markdown( | |
| "# Embedding Explorer\n" | |
| "Words in AI models are stored as **vectors** β long lists of numbers " | |
| "that encode meaning. Similar words have similar vectors. " | |
| "You can even do **vector math**: *king β man + woman β queen*. " | |
| "This tool lets you explore these representations in 3D using " | |
| "[GloVe](https://nlp.stanford.edu/projects/glove/) word vectors " | |
| f"({len(VOCAB):,} words, {DIMS} dimensions)." | |
| ) | |
| gr.Markdown( | |
| "*Enter words to see them in 3D, or try vector arithmetic " | |
| "with `+` and `β`. Separate groups with commas. " | |
| "Click an item below the plot to see its nearest neighbors.*" | |
| ) | |
# Input row: the query textbox plus the Explore and Share buttons.
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| exp_in = gr.Textbox( | |
| label="Words or expressions (comma-separated)", | |
| placeholder="dog cat fish or king - man + woman", | |
| lines=1, | |
| ) | |
| with gr.Column(scale=1): | |
| with gr.Row(): | |
| exp_btn = gr.Button("Explore", variant="primary") | |
| share_btn = gr.Button("Share", variant="secondary", | |
| scale=0, min_width=80) | |
# Starts hidden; on_share returns an update that fills it and makes it visible.
| share_url = gr.Textbox(label="Share URL", visible=False, | |
| interactive=False, buttons=["copy"]) | |
# Clicking an example writes it into exp_in; the exp_in.change handler then
# runs the exploration when the text matches one of EXAMPLES.
| with gr.Column(elem_classes=["purple-examples"]): | |
| gr.Examples( | |
| examples=[[e] for e in EXAMPLES], | |
| inputs=[exp_in], | |
| label="Try these", | |
| ) | |
| exp_plot = gr.Plot(label="Embedding Space", elem_classes=["plot-viewport"]) | |
| exp_status = gr.Markdown("") | |
# Hidden and empty at start; on_explore fills choices/value and reveals it
# once there are labels.
| vis_cbg = gr.CheckboxGroup( | |
| label="Visible items (uncheck to hide)", | |
| choices=[], value=[], | |
| visible=False, interactive=True, | |
| elem_classes=["vis-cbg"], | |
| ) | |
| with gr.Row(): | |
# Hidden and empty at start; receives the radio update returned by explore().
| exp_radio = gr.Radio( | |
| label="Click to see nearest neighbors", | |
| choices=[], value=None, | |
| visible=False, interactive=True, | |
| elem_classes=["nbr-radio"], | |
| ) | |
# Neighbor count as strings "3" through "12", defaulting to N_NEIGHBORS.
| nn_dropdown = gr.Dropdown( | |
| label="Neighbors", | |
| choices=[str(i) for i in range(3, 13)], | |
| value=str(N_NEIGHBORS), | |
| interactive=True, | |
| scale=0, | |
| min_width=90, | |
| elem_classes=["nn-dropdown"], | |
| ) | |
| # ── Event handlers ── | |
| def _parse_camera_json(camera_json): | |
| """Parse camera JSON string (from JS bridge) into Plotly camera dict.""" | |
| if not camera_json: | |
| return None | |
| try: | |
| return json.loads(camera_json) | |
| except (json.JSONDecodeError, TypeError): | |
| return None | |
| def _get_nn(nn_val): | |
| """Parse neighbor count from dropdown value.""" | |
| try: | |
| return int(nn_val) | |
| except (TypeError, ValueError): | |
| return N_NEIGHBORS | |
def on_explore(input_text, nn_val=None):
    """Run a fresh exploration and reset the visibility checkboxes.

    An ``@word`` token in the input pre-selects that word for the neighbor
    view; every ``@...`` token is removed before the text is explored:
        dog cat fish @dog  ->  plots all three, shows dog's neighbors

    Returns (figure, status, radio update, labels, checkbox update,
    input-textbox update carrying the cleaned text).
    """
    selected = None
    marker = re.search(r"@(\S+)", input_text) if input_text else None
    if marker:
        selected = marker.group(1).lower()
        input_text = re.sub(r"\s*@\S+", "", input_text).strip()

    fig, status, radio, labels = explore(
        input_text, selected, n_neighbors=_get_nn(nn_val)
    )
    # Every label starts out checked; the group stays hidden when there are none.
    checkbox_update = gr.update(choices=labels, value=labels, visible=bool(labels))
    return fig, status, radio, labels, checkbox_update, gr.update(value=input_text)
def on_radio(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Re-render after a neighbor selection, keeping visibility and camera.

    While a share link is loading the event is swallowed: the three visual
    outputs are left untouched and the loading flag is reset to False.
    """
    if is_loading:
        return gr.update(), gr.update(), gr.update(), False

    # Nothing counts as hidden unless both the label list and the checked
    # list are non-empty.
    if all_labels and visible:
        hidden = set(all_labels) - set(visible)
    else:
        hidden = set()

    fig, status, radio, _ = explore(
        input_text,
        selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_visibility(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Re-render after a checkbox toggle, keeping the camera.

    While a share link is loading the event is swallowed: the three visual
    outputs are left untouched and the loading flag is reset to False.
    """
    if is_loading:
        return gr.update(), gr.update(), gr.update(), False

    hidden = set()
    if all_labels:
        hidden = set(all_labels).difference(visible)

    # A selection that has just been hidden is dropped.
    if selected and selected != "(clear)" and selected in hidden:
        selected = None

    fig, status, radio, _ = explore(
        input_text,
        selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_nn_change(input_text, selected, all_labels, visible, camera_json, is_loading, nn_val):
    """Re-render with a new neighbor count, but only when an item is selected.

    When a share link is loading, or nothing (or "(clear)") is selected, the
    three visual outputs are left untouched and the loading flag is reset to
    False.
    """
    nothing_selected = not selected or selected == "(clear)"
    if is_loading or nothing_selected:
        return gr.update(), gr.update(), gr.update(), False

    # Nothing counts as hidden unless both the label list and the checked
    # list are non-empty.
    if all_labels and visible:
        hidden = set(all_labels) - set(visible)
    else:
        hidden = set()

    fig, status, radio, _ = explore(
        input_text,
        selected,
        hidden=hidden or None,
        camera=_parse_camera_json(camera_json),
        n_neighbors=_get_nn(nn_val),
    )
    return fig, status, radio, False
def on_share(input_text, selected, visible, camera_json, nn_val, request: gr.Request):
    """Build a share URL that encodes the current explorer state.

    Query parameters written, in this order:
        q   - the stripped input text (required; without it nothing is shared)
        sel - the selected item, unless it is empty or "(clear)"
        vis - the checked items joined with commas, when any are checked
        cam - the encoded camera, when _encode_camera returns something
        nn  - the neighbor count, only when it differs from N_NEIGHBORS

    Returns an update for the share-URL textbox: the shortened URL, the full
    URL when the app is served from localhost, or "Nothing to share".
    """
    query = input_text.strip() if input_text else ""
    if not query:
        return gr.update(value="Nothing to share", visible=True)

    params = {"q": query}
    if selected and selected != "(clear)":
        params["sel"] = selected
    # The checked items are always recorded when there are any. An empty list
    # is left out: the loader treats a missing/empty "vis" as "show everything".
    if isinstance(visible, list) and visible:
        params["vis"] = ",".join(visible)
    if camera_json:
        encoded = _encode_camera(camera_json)
        if encoded:
            params["cam"] = encoded
    nn = _get_nn(nn_val)
    if nn != N_NEIGHBORS:
        params["nn"] = str(nn)

    # Prefer the Host header of the incoming request so local runs on a
    # non-default port still produce a working link.
    base_url = _BASE_URL
    if request:
        host = request.headers.get("host", "")
        if host:
            scheme = "https" if _SPACE_ID else "http"
            base_url = f"{scheme}://{host}/"
    long_url = base_url + "?" + urllib.parse.urlencode(params)

    # Rebrandly rejects non-public URLs, so local links are returned as-is.
    # Only the base URL is inspected: the query may legitimately contain the
    # word "localhost" and must not disable shortening.
    if "localhost" in base_url or "127.0.0.1" in base_url:
        return gr.update(value=long_url, visible=True)
    return gr.update(value=_shorten_url(long_url), visible=True)
| # ── Wire up events ── | |
# Membership lookup for "was this text put there by clicking an example?".
_EXAMPLE_SET = set(EXAMPLES)


def on_input_change(input_text, nn_val):
    """Explore automatically when the textbox holds one of the examples.

    Any other text leaves all six outputs untouched; the user then has to
    press Explore or Enter.
    """
    is_example = bool(input_text) and input_text.strip() in _EXAMPLE_SET
    if not is_example:
        return tuple(gr.update() for _ in range(6))
    return on_explore(input_text, nn_val)
# Typing/example clicks, the Explore button and Enter all take the same inputs
# and refresh the same six outputs.
_explore_inputs = [exp_in, nn_dropdown]
_explore_outputs = [exp_plot, exp_status, exp_radio, all_labels_state, vis_cbg, exp_in]
exp_in.change(on_input_change, inputs=_explore_inputs, outputs=_explore_outputs)
exp_btn.click(on_explore, inputs=_explore_inputs, outputs=_explore_outputs)
exp_in.submit(on_explore, inputs=_explore_inputs, outputs=_explore_outputs)

# Selection, visibility and neighbor-count changes re-render the current plot.
# They read camera_txt, which the polling script in the page head keeps current.
_rerender_inputs = [
    exp_in, exp_radio, all_labels_state, vis_cbg, camera_txt, loading_share, nn_dropdown,
]
_rerender_outputs = [exp_plot, exp_status, exp_radio, loading_share]
exp_radio.change(on_radio, inputs=_rerender_inputs, outputs=_rerender_outputs)
vis_cbg.change(on_visibility, inputs=_rerender_inputs, outputs=_rerender_outputs)
nn_dropdown.change(on_nn_change, inputs=_rerender_inputs, outputs=_rerender_outputs)

# Share reads the same polled camera value and fills the share-URL textbox.
share_btn.click(
    fn=on_share,
    inputs=[exp_in, exp_radio, vis_cbg, camera_txt, nn_dropdown],
    outputs=[share_url],
)
| # ── Share URL loading ── | |
def load_share_params(request: gr.Request):
    """Step 1: Copy the page's query parameters into a plain dict ({} without a request)."""
    if not request:
        return {}
    return dict(request.query_params)
def apply_share_params(params):
    """Step 2: Rebuild the UI from the query parameters of a share URL.

    Keys read from ``params``: ``q`` (input text), ``sel`` (selected item),
    ``vis`` (comma-separated visible labels), ``cam`` (encoded camera) and
    ``nn`` (neighbor count). All values come straight from the URL and are
    treated as untrusted.

    Returns a 9-tuple in the order the load chain expects: exp_in, exp_plot,
    exp_status, exp_radio, vis_cbg, all_labels_state, camera_txt,
    loading_share, nn_dropdown.
    """
    params = params or {}

    # Accept "nn" only when it is one of the values the neighbors dropdown
    # offers ("3".."12"). Membership in this set also avoids int() on strings
    # such as "²", which str.isdigit() accepts but int() rejects.
    nn_choices = {str(i) for i in range(3, 13)}
    nn_str = (params.get("nn") or "").strip()
    nn = int(nn_str) if nn_str in nn_choices else None
    nn_update = gr.update(value=str(nn)) if nn is not None else gr.update()

    if "q" not in params:
        # Nothing to restore apart from (possibly) the neighbor count.
        return (
            gr.update(),  # exp_in
            gr.update(),  # exp_plot
            gr.update(),  # exp_status
            gr.update(),  # exp_radio
            gr.update(),  # vis_cbg
            [],           # all_labels_state
            gr.update(),  # camera_txt
            False,        # loading_share
            nn_update,    # nn_dropdown
        )

    input_text = params.get("q", "")
    selected = params.get("sel") or None
    camera = _parse_camera(params.get("cam"))
    # Without a valid "nn" the dropdown keeps its default, so render with the
    # same default the other handlers fall back to.
    render_nn = nn if nn is not None else N_NEIGHBORS

    # First pass with everything visible, only to learn the labels.
    _, _, _, labels = explore(input_text, None, camera=camera, n_neighbors=render_nn)

    vis_str = params.get("vis")
    if vis_str:
        requested = {item.strip() for item in vis_str.split(",")}
        # Keep only real labels so the checkbox group never receives a value
        # that is not among its choices.
        visible = [label for label in labels if label in requested]
        hidden = set(labels) - requested
    else:
        visible = labels
        hidden = set()

    # Same rule as the visibility handler: a hidden item cannot stay selected.
    if selected and selected != "(clear)" and selected in hidden:
        selected = None

    fig, status, radio, _ = explore(
        input_text, selected, hidden=hidden or None, camera=camera, n_neighbors=render_nn
    )
    cbg = gr.update(choices=labels, value=visible, visible=bool(labels))
    # Pre-populate camera_txt so subsequent re-renders preserve the camera.
    camera_json = json.dumps(camera) if camera else ""
    return (
        gr.update(value=input_text),
        fig,
        status,
        radio,
        cbg,
        labels,
        gr.update(value=camera_json),
        True,  # loading_share: suppress the cascading change events
        nn_update,
    )
# On page load: first capture the query parameters, then apply them to the UI.
_restore_outputs = [
    exp_in, exp_plot, exp_status, exp_radio, vis_cbg,
    all_labels_state, camera_txt, loading_share, nn_dropdown,
]
_page_load = demo.load(fn=load_share_params, outputs=[share_params])
_page_load.then(fn=apply_share_params, inputs=[share_params], outputs=_restore_outputs)
# Start the app; theme, stylesheet and <head> scripts are supplied at launch.
demo.launch(
    theme=THEME,
    css=CSS,
    head=FORCE_LIGHT,
)