# microembeddings / microembeddings.py
# fix: robust text8 loading, gensim attribution in UI, training error handling (43fd8a7)
"""
microembeddings.py — Word2Vec skip-gram with negative sampling from scratch.
~180 lines, NumPy only. Inspired by Karpathy's microGPT.
"""
import numpy as np
import os
import urllib.request
import zipfile
from collections import Counter
# --- Hyperparameters ---
EMBED_DIM = 50              # dimensionality of the learned word vectors
WINDOW_SIZE = 5             # max context radius; per-position radius is sampled in 1..WINDOW_SIZE
NUM_NEGATIVES = 5           # negative samples drawn per positive (center, context) pair
LEARNING_RATE = 0.025       # initial SGD learning rate (decayed linearly in train())
MIN_LR = 0.0001             # floor for the decayed learning rate
EPOCHS = 3                  # passes over the training corpus
MIN_COUNT = 5               # drop words rarer than this from the vocabulary
MAX_VOCAB = 10000           # keep at most this many most-frequent words
SUBSAMPLE_THRESHOLD = 1e-4  # Mikolov-style frequent-word subsampling threshold
TEXT8_FILE = "text8"        # local corpus filename
TEXT8_ZIP = "text8.zip"     # local archive filename
TEXT8_URL = "http://mattmahoney.net/dc/text8.zip"  # download source for the corpus
def describe_text8_source():
    """Summarize how training data will be loaded."""
    # Check candidate local sources in priority order; fall through to
    # the download message if neither exists.
    candidates = (
        (TEXT8_FILE, "Local text8 corpus found."),
        (TEXT8_ZIP, "Local text8.zip found; it will be extracted on first train."),
    )
    for path, message in candidates:
        if os.path.exists(path):
            return message
    return "text8 is not bundled; Train will download it on first run."
def load_text8(max_words=500000):
    """Return the first ``max_words`` tokens of the text8 corpus.

    Uses a pre-existing local ``text8`` file or ``text8.zip`` archive when
    present; otherwise downloads the archive, extracts it, and removes the
    downloaded zip (a user-supplied local zip is left in place).

    Args:
        max_words: Maximum number of whitespace-separated tokens to return.

    Returns:
        List of word strings.

    Raises:
        RuntimeError: If the corpus cannot be downloaded, or the archive is
            missing/corrupt.
    """
    downloaded = False
    if not os.path.exists(TEXT8_FILE):
        if not os.path.exists(TEXT8_ZIP):
            print("Downloading text8...")
            try:
                urllib.request.urlretrieve(TEXT8_URL, TEXT8_ZIP)
            except OSError as exc:
                # Remove any partially-downloaded archive so a later retry
                # doesn't attempt to extract a truncated zip.
                if os.path.exists(TEXT8_ZIP):
                    os.remove(TEXT8_ZIP)
                raise RuntimeError(
                    "Could not load text8. Add a local text8/text8.zip file or allow outbound download."
                ) from exc
            downloaded = True
        try:
            with zipfile.ZipFile(TEXT8_ZIP) as z:
                z.extractall()
        except (OSError, zipfile.BadZipFile) as exc:
            raise RuntimeError("text8.zip is missing or invalid.") from exc
        if downloaded:
            # Only delete the archive we fetched ourselves.
            os.remove(TEXT8_ZIP)
    # text8 is plain ASCII; specify utf-8 explicitly rather than relying on
    # the platform default encoding.
    with open(TEXT8_FILE, encoding="utf-8") as f:
        words = f.read().split()[:max_words]
    print(f"Loaded {len(words)} words")
    return words
def build_vocab(words, min_count=MIN_COUNT, max_vocab=MAX_VOCAB):
    """Build word<->index mappings and raw counts for the kept vocabulary.

    Keeps at most ``max_vocab`` of the most frequent words, dropping any
    word seen fewer than ``min_count`` times.

    Returns:
        (word2idx, idx2word, freqs) where ``freqs[i]`` is the raw count of
        the word at index ``i``.
    """
    counter = Counter(words)
    kept = [word for word, n in counter.most_common(max_vocab) if n >= min_count]
    word2idx = {}
    idx2word = {}
    for index, word in enumerate(kept):
        word2idx[word] = index
        idx2word[index] = word
    freqs = np.array([counter[word] for word in kept], dtype=np.float64)
    print(f"Vocabulary: {len(kept)} words")
    return word2idx, idx2word, freqs
def prepare_corpus(words, word2idx, freqs):
    """Map words to vocab indices, randomly dropping very frequent words.

    Mikolov-style subsampling: each occurrence of word w is discarded with
    probability max(0, 1 - sqrt(t / f(w))), where f is relative frequency
    and t is SUBSAMPLE_THRESHOLD.
    """
    total = freqs.sum()
    drop_prob = np.clip(1 - np.sqrt(SUBSAMPLE_THRESHOLD * total / freqs), 0, 1)
    corpus = []
    for token in words:
        idx = word2idx.get(token)
        if idx is None:
            continue
        # Keep the occurrence only when the draw clears the drop probability.
        if np.random.random() >= drop_prob[idx]:
            corpus.append(idx)
    print(f"Training corpus: {len(corpus)} tokens (after subsampling)")
    return np.array(corpus)
def build_neg_table(freqs, table_size=100_000_000):
    """Return the unigram^0.75 negative-sampling distribution.

    Despite the name (kept for API compatibility), this returns a normalized
    probability vector suitable for ``np.random.choice(..., p=dist)`` rather
    than a materialized sampling table, so ``table_size`` is unused.

    Args:
        freqs: Raw word counts, shape (vocab_size,).
        table_size: Unused; retained for backward compatibility.

    Returns:
        float64 array of probabilities summing to 1.
    """
    weights = freqs ** 0.75
    return weights / weights.sum()
def train(corpus, vocab_size, neg_dist, epochs=EPOCHS, embed_dim=EMBED_DIM,
          lr=LEARNING_RATE, window=WINDOW_SIZE, num_neg=NUM_NEGATIVES,
          callback=None):
    """Train skip-gram with negative sampling. Returns embedding matrix W.

    Args:
        corpus: 1-D int array of vocab indices (already subsampled).
        vocab_size: Number of rows in the embedding matrices.
        neg_dist: Probability vector over the vocab for negative sampling.
        epochs, embed_dim, lr, window, num_neg: Standard word2vec knobs.
        callback: Optional fn(epoch, i, corpus_len, avg_loss), invoked every
            10000 positions. Losses are only recorded when a callback is given.

    Returns:
        (W, losses): input ("center") embeddings of shape
        (vocab_size, embed_dim), and the list of recorded average losses
        (empty when callback is None).
    """
    # Init scale ~1/sqrt(dim) gives vectors enough room to differentiate
    scale = 0.5 / embed_dim ** 0.5
    W = (np.random.randn(vocab_size, embed_dim) * scale).astype(np.float32)
    # Context ("output") embeddings start at zero, as in the original word2vec.
    C = np.zeros((vocab_size, embed_dim), dtype=np.float32)
    # Each position draws actual_window ~ uniform(1..window), generating
    # 2*actual_window context pairs. Expected pairs = 2 * E[uniform(1..w)] = w+1
    total_steps = epochs * len(corpus) * (window + 1)
    step = 0
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        count = 0
        # Skip the first/last `window` positions so every context index is in range.
        for i in range(window, len(corpus) - window):
            center = corpus[i]
            # Dynamic window: sampling the radius effectively weights nearer
            # context words more heavily, as in word2vec.
            actual_window = np.random.randint(1, window + 1)
            context_indices = []
            for j in range(i - actual_window, i + actual_window + 1):
                if j != i:
                    context_indices.append(corpus[j])
            for ctx in context_indices:
                # Positive pair
                dot = np.dot(W[center], C[ctx])
                # Clip logits to [-6, 6] so exp() stays numerically stable.
                sig = 1.0 / (1.0 + np.exp(-np.clip(dot, -6, 6)))
                grad_pos = (sig - 1.0)
                # Negative samples
                neg_samples = np.random.choice(vocab_size, size=num_neg, p=neg_dist)
                neg_dots = C[neg_samples] @ W[center]
                neg_sigs = 1.0 / (1.0 + np.exp(-np.clip(neg_dots, -6, 6)))
                # Loss
                loss = -np.log(sig + 1e-10) - np.sum(np.log(1 - neg_sigs + 1e-10))
                epoch_loss += loss
                count += 1
                # Gradient updates (use old W[center] for C updates)
                # Learning rate decays linearly over total_steps, floored at MIN_LR.
                alpha = max(lr * (1 - step / total_steps), MIN_LR)
                w_old = W[center].copy()
                grad_w = grad_pos * C[ctx] + (neg_sigs[:, None] * C[neg_samples]).sum(axis=0)
                W[center] -= alpha * grad_w
                C[ctx] -= alpha * grad_pos * w_old
                C[neg_samples] -= alpha * neg_sigs[:, None] * w_old
                step += 1
            if callback and i % 10000 == 0:
                avg_loss = epoch_loss / max(count, 1)
                callback(epoch, i, len(corpus), avg_loss)
                losses.append(avg_loss)
        avg_loss = epoch_loss / max(count, 1)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        if callback:
            losses.append(avg_loss)
    return W, losses
def normalize(W):
    """Return W with every row scaled to unit L2 norm.

    Near-zero norms are floored at 1e-10 to avoid division by zero.
    """
    row_norms = np.sqrt((W * W).sum(axis=1, keepdims=True))
    return W / np.maximum(row_norms, 1e-10)
def most_similar(word, W_norm, word2idx, idx2word, topn=10):
    """Find the ``topn`` most similar words by cosine similarity.

    Args:
        word: Query word; returns [] if out of vocabulary.
        W_norm: Row-normalized embedding matrix.
        word2idx, idx2word: Vocabulary mappings.
        topn: Number of neighbors to return.

    Returns:
        List of (word, similarity) pairs, best first, never including the
        query word itself.
    """
    if word not in word2idx:
        return []
    query_idx = word2idx[word]
    sims = W_norm @ W_norm[query_idx]
    results = []
    # Exclude the query index explicitly instead of assuming it sorts first:
    # a duplicate vector ties at similarity 1.0 and could displace it.
    for i in np.argsort(-sims):
        if int(i) == query_idx:
            continue
        results.append((idx2word[int(i)], float(sims[i])))
        if len(results) == topn:
            break
    return results
def analogy(a, b, c, W_norm, word2idx, idx2word, topn=5):
    """Solve: a is to b as c is to ? (b - a + c)"""
    # All three query words must be in-vocabulary.
    if any(w not in word2idx for w in (a, b, c)):
        return []
    target = W_norm[word2idx[b]] - W_norm[word2idx[a]] + W_norm[word2idx[c]]
    target = target / (np.linalg.norm(target) + 1e-10)
    sims = W_norm @ target
    # The query words themselves are never valid answers.
    skip = {word2idx[a], word2idx[b], word2idx[c]}
    ranked = [
        (idx2word[int(i)], float(sims[i]))
        for i in np.argsort(-sims)
        if int(i) not in skip
    ]
    return ranked[:topn]
if __name__ == "__main__":
    # End-to-end demo: load corpus, train embeddings, probe the result.
    words = load_text8()
    word2idx, idx2word, freqs = build_vocab(words)
    corpus = prepare_corpus(words, word2idx, freqs)
    neg_dist = build_neg_table(freqs)

    def progress(epoch, i, total, loss):
        # Single-line progress indicator, overwritten in place via \r.
        pct = i / total * 100
        print(f"  Epoch {epoch+1}: {pct:.0f}% complete, loss={loss:.4f}", end="\r")

    W, losses = train(corpus, len(word2idx), neg_dist, callback=progress)
    W_norm = normalize(W)

    print("\n\n--- Nearest Neighbors ---")
    for query in ("king", "computer", "france", "dog"):
        neighbors = most_similar(query, W_norm, word2idx, idx2word)
        formatted = ", ".join(f"{w} ({s:.3f})" for w, s in neighbors[:5])
        print(f"\n{query}: {formatted}")

    print("\n--- Analogies ---")
    triples = [("man", "king", "woman"), ("france", "paris", "germany"),
               ("big", "bigger", "small")]
    for a, b, c in triples:
        hits = analogy(a, b, c, W_norm, word2idx, idx2word)
        best = hits[0][0] if hits else "?"
        print(f"{a} : {b} :: {c} : {best}")