"""Gradio demo app for microembeddings.

Builds a four-tab UI (Train, Explore, Analogies, Nearest Neighbors) around
the helpers imported from the local ``microembeddings`` module. Embeddings
and vocabulary live in the module-level ``state`` dict, which is filled
either from pre-trained files on disk or by the Train tab.
"""

import json

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from microembeddings import (
    load_text8, build_vocab, prepare_corpus, build_neg_table,
    train, normalize, most_similar, analogy, describe_text8_source
)

# --- Global state ---
# W:        embedding matrix as loaded from disk / returned by train()
# W_norm:   normalize(W), used for similarity queries
# word2idx: word -> row index; idx2word: row index -> word
# losses:   loss values recorded during training (or stored with the
#           pre-trained vocabulary file)
state = {"W": None, "W_norm": None, "word2idx": None, "idx2word": None, "losses": []}

# Word lists offered by the "Highlight category" dropdown on the Explore tab.
HIGHLIGHT_CATEGORIES = {
    "Countries": ["france", "germany", "italy", "spain", "china", "japan",
                  "india", "russia", "england", "canada", "brazil",
                  "australia", "mexico", "korea"],
    "Animals": ["dog", "cat", "horse", "bird", "fish", "lion", "bear",
                "wolf", "snake", "elephant"],
    "Numbers": ["one", "two", "three", "four", "five", "six", "seven",
                "eight", "nine", "ten"],
    "Colors": ["red", "blue", "green", "yellow", "black", "white", "brown",
               "gold", "silver"],
}

ACCENT_COLOR = "#4F46E5"
HIGHLIGHT_COLOR = "#E11D48"
BACKGROUND_COLOR = "#93C5FD"


def load_pretrained():
    """Load pre-trained embeddings if available.

    Reads ``pretrained_W.npy`` and ``pretrained_vocab.json`` from the current
    working directory and fills ``state``.

    Returns:
        A human-readable status message for the UI. When either file is
        missing, ``state`` is left untouched and the message says so.
    """
    try:
        W = np.load("pretrained_W.npy")
        with open("pretrained_vocab.json", encoding="utf-8") as f:
            meta = json.load(f)
    except FileNotFoundError:
        return "No pre-trained embeddings found. Train from scratch!"

    vocab = meta["vocab"]
    # Compute everything first so ``state`` is only touched once all the
    # derived values exist.
    W_norm = normalize(W)
    state.update(
        W=W,
        W_norm=W_norm,
        word2idx={w: i for i, w in enumerate(vocab)},
        idx2word={i: w for i, w in enumerate(vocab)},
        losses=meta.get("losses", []),
    )
    return (
        "Loaded pre-trained full-text8 gensim vectors: "
        f"{W.shape[0]} words x {W.shape[1]} dims"
    )


# --- Tab 1: Train ---
def run_training(embed_dim, window_size, learning_rate, num_neg, progress=gr.Progress()):
    """Train embeddings on a 500k-word text8 subset and plot the loss.

    Args:
        embed_dim: Embedding dimension (cast to ``int``).
        window_size: Context window size (cast to ``int``).
        learning_rate: Learning rate passed to ``train`` as ``lr``.
        num_neg: Number of negative samples (cast to ``int``).
        progress: Gradio progress tracker; the ``gr.Progress()`` default is
            how Gradio injects it into event handlers.

    Returns:
        ``(figure, status_message)``. On failure the figure is an empty
        "Training unavailable" plot and the message carries the error.
    """
    fig = go.Figure()
    try:
        progress(0, desc="Loading text8...")
        words = load_text8(500000)
        word2idx, idx2word, freqs = build_vocab(words)
        corpus = prepare_corpus(words, word2idx, freqs)
        neg_dist = build_neg_table(freqs)
        losses = []

        def callback(epoch, i, total, loss):
            # Guard the division so a zero ``total`` cannot abort training.
            pct = i / total if total else 0.0
            progress(pct, desc=f"Epoch {epoch+1}: loss={loss:.4f}")
            losses.append(loss)

        W, _ = train(
            corpus, len(word2idx), neg_dist,
            epochs=3,
            embed_dim=int(embed_dim),
            lr=learning_rate,
            window=int(window_size),
            num_neg=int(num_neg),
            callback=callback,
        )
        W_norm = normalize(W)

        # Publish the new vocabulary together with the new matrix, and only
        # after training succeeded. Updating the vocabulary earlier would
        # pair the previous embeddings with the new word indices if training
        # failed (or while it was still running), so lookups in the other
        # tabs would hit the wrong rows.
        state.update(
            W=W,
            W_norm=W_norm,
            word2idx=word2idx,
            idx2word=idx2word,
            losses=losses,
        )

        fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss",
                                 line=dict(color=ACCENT_COLOR)))
        fig.update_layout(title="Training Loss", xaxis_title="Step",
                          yaxis_title="Loss", template="plotly_white")
        return fig, f"Done! {W.shape[0]} words x {W.shape[1]} dims"
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of
        # raising into Gradio. ``state`` still holds the previous embeddings.
        fig.update_layout(title="Training unavailable", template="plotly_white")
        return fig, f"Training failed: {exc}"


# --- Tab 2: Explore ---
def explore_embeddings(method, num_words, category):
    """Project the first ``num_words`` embeddings to 2D and plot them.

    Args:
        method: ``"PCA"`` selects PCA; any other value selects t-SNE.
        num_words: Upper bound on how many leading rows of ``W`` to plot.
        category: Key of ``HIGHLIGHT_CATEGORIES`` whose words are emphasised;
            unknown keys (e.g. ``"None"``) highlight nothing.

    Returns:
        A Plotly figure, or ``None`` when no embeddings are loaded or fewer
        than two rows are available to project.
    """
    if state["W"] is None:
        return None

    # Bound by both the vocabulary and the matrix so labels and coordinates
    # always have the same length.
    n = min(int(num_words), len(state["idx2word"]), state["W"].shape[0])
    if n < 2:
        # A 2-component projection needs at least two samples, and t-SNE
        # needs a positive perplexity (min(30, n - 1)).
        return None

    W_sub = state["W"][:n]
    words_sub = [state["idx2word"][i] for i in range(n)]

    if method == "PCA":
        coords = PCA(n_components=2).fit_transform(W_sub)
    else:
        # Perplexity must stay below the number of samples.
        coords = TSNE(n_components=2, perplexity=min(30, n - 1),
                      random_state=42).fit_transform(W_sub)

    highlight_words = set(HIGHLIGHT_CATEGORIES.get(category, []))
    is_highlighted = [w in highlight_words for w in words_sub]
    colors = [HIGHLIGHT_COLOR if hit else BACKGROUND_COLOR for hit in is_highlighted]
    sizes = [10 if hit else 4 for hit in is_highlighted]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=coords[:, 0], y=coords[:, 1], mode="markers",
        marker=dict(size=sizes, color=colors, opacity=0.7),
        text=words_sub, hoverinfo="text"
    ))
    # Label highlighted points directly so they are readable without hovering.
    for i, (w, hit) in enumerate(zip(words_sub, is_highlighted)):
        if hit:
            fig.add_annotation(x=coords[i, 0], y=coords[i, 1], text=w,
                               showarrow=False, yshift=12,
                               font=dict(size=10, color=HIGHLIGHT_COLOR))
    fig.update_layout(title=f"Embedding Space ({method})",
                      template="plotly_white",
                      width=800, height=600, showlegend=False)
    return fig


def _similarity_bar_chart(results, title):
    """Build a horizontal bar chart from non-empty ``(word, score)`` pairs."""
    words_r, sims_r = zip(*results)
    fig = go.Figure(go.Bar(x=list(sims_r), y=list(words_r), orientation="h",
                           marker_color=ACCENT_COLOR))
    # Reverse the y axis so the first result appears at the top.
    fig.update_layout(title=title, xaxis_title="Cosine similarity",
                      template="plotly_white",
                      yaxis=dict(autorange="reversed"))
    return fig


# --- Tab 3: Analogies ---
def solve_analogy(a, b, c):
    """Answer "A is to B as C is to ?" using the loaded embeddings.

    Args:
        a, b, c: The three input words; surrounding whitespace and case are
            ignored.

    Returns:
        ``(text, figure)``. ``figure`` is ``None`` whenever there is nothing
        to plot (no embeddings, blank input, unknown words, no results).
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None

    # ``or ""`` tolerates a textbox that delivers None instead of "".
    a, b, c = ((w or "").strip().lower() for w in (a, b, c))
    if not (a and b and c):
        return "Please enter all three words.", None

    results = analogy(a, b, c, state["W_norm"], state["word2idx"],
                      state["idx2word"])
    if not results:
        missing = [w for w in (a, b, c) if w not in state["word2idx"]]
        if missing:
            return f"Word(s) not in vocabulary: {', '.join(missing)}", None
        # All words are known but nothing came back; avoid a message that
        # ends in a dangling colon.
        return "No analogy results found.", None

    text = f"{a} is to {b} as {c} is to...\n\n"
    text += "\n".join(f" {w}: {s:.4f}" for w, s in results)
    return text, _similarity_bar_chart(results, f"{a} : {b} :: {c} : ?")


# --- Tab 4: Nearest Neighbors ---
def find_neighbors(word):
    """List the words most similar to ``word``.

    Args:
        word: Query word; surrounding whitespace and case are ignored.

    Returns:
        ``(text, figure)``. ``figure`` is ``None`` when no embeddings are
        loaded, the input is blank, or ``most_similar`` returns nothing.
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None

    word = (word or "").strip().lower()
    if not word:
        return "Please enter a word.", None

    results = most_similar(word, state["W_norm"], state["word2idx"],
                           state["idx2word"])
    if not results:
        return f"'{word}' not in vocabulary", None

    text = "\n".join(f"{w}: {s:.4f}" for w, s in results)
    return text, _similarity_bar_chart(results, f"Nearest neighbors of '{word}'")


# --- Build UI ---
load_msg = load_pretrained()
corpus_msg = describe_text8_source()

with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# microembeddings\n"
        "*Word2Vec skip-gram from scratch — train, explore, and play with word vectors*\n\n"
        "Companion to the blog post: "
        "[microembeddings: Understanding Word Vectors from Scratch]"
        "(https://kshreyas.dev/post/microembeddings/)"
    )
    gr.Markdown(f"*{load_msg}*")
    gr.Markdown(
        "*Preloaded vectors use gensim Word2Vec on the full 17M-word text8 corpus.* "
        "*The Train tab reruns the NumPy implementation on a 500k-word subset so it stays interactive.*"
    )

    with gr.Tabs():
        with gr.Tab("Train"):
            gr.Markdown(
                "Train word embeddings from scratch on text8 (cleaned Wikipedia).\n\n"
                f"{corpus_msg}"
            )
            with gr.Row():
                dim_slider = gr.Slider(25, 100, value=50, step=25,
                                       label="Embedding dimension")
                win_slider = gr.Slider(1, 10, value=5, step=1,
                                       label="Window size")
            with gr.Row():
                lr_slider = gr.Slider(0.001, 0.05, value=0.025, step=0.001,
                                      label="Learning rate")
                neg_slider = gr.Slider(1, 15, value=5, step=1,
                                       label="Negative samples")
            train_btn = gr.Button("Train", variant="primary")
            train_status = gr.Textbox(label="Status", interactive=False)
            loss_plot = gr.Plot(label="Training Loss")
            train_btn.click(run_training,
                            [dim_slider, win_slider, lr_slider, neg_slider],
                            [loss_plot, train_status])

        with gr.Tab("Explore"):
            gr.Markdown(
                "Visualize the embedding space in 2D. "
                "Similar words cluster together."
            )
            with gr.Row():
                method_radio = gr.Radio(["PCA", "t-SNE"], value="PCA",
                                        label="Projection method")
                num_words_slider = gr.Slider(100, 500, value=200, step=50,
                                             label="Number of words")
                cat_dropdown = gr.Dropdown(
                    ["None", "Countries", "Animals", "Numbers", "Colors"],
                    value="None", label="Highlight category"
                )
            explore_btn = gr.Button("Visualize", variant="primary")
            explore_plot = gr.Plot(label="Embedding Space")
            explore_btn.click(explore_embeddings,
                              [method_radio, num_words_slider, cat_dropdown],
                              explore_plot)

        with gr.Tab("Analogies"):
            gr.Markdown(
                "Word vector arithmetic: **A is to B as C is to ?**\n\n"
                "Computed as: `B - A + C ≈ ?`"
            )
            with gr.Row():
                a_input = gr.Textbox(label="A", placeholder="man", value="man")
                b_input = gr.Textbox(label="B", placeholder="king", value="king")
                c_input = gr.Textbox(label="C", placeholder="woman", value="woman")
            analogy_btn = gr.Button("Solve", variant="primary")
            gr.Examples(
                [["man", "king", "woman"],
                 ["france", "paris", "germany"],
                 ["big", "bigger", "small"]],
                inputs=[a_input, b_input, c_input]
            )
            analogy_text = gr.Textbox(label="Results", interactive=False, lines=6)
            analogy_plot = gr.Plot(label="Similarity")
            analogy_btn.click(solve_analogy, [a_input, b_input, c_input],
                              [analogy_text, analogy_plot])

        with gr.Tab("Nearest Neighbors"):
            gr.Markdown("Find the most similar words by cosine similarity.")
            word_input = gr.Textbox(label="Enter a word", placeholder="king")
            nn_btn = gr.Button("Search", variant="primary")
            nn_text = gr.Textbox(label="Results", interactive=False, lines=10)
            nn_plot = gr.Plot(label="Similarity")
            nn_btn.click(find_neighbors, word_input, [nn_text, nn_plot])

demo.launch()