shreyask committed on
Commit
43fd8a7
·
verified ·
1 Parent(s): 451cbf8

fix: robust text8 loading, gensim attribution in UI, training error handling

Browse files
Files changed (2) hide show
  1. app.py +44 -28
  2. microembeddings.py +31 -9
app.py CHANGED
@@ -6,7 +6,7 @@ from sklearn.decomposition import PCA
6
  from sklearn.manifold import TSNE
7
  from microembeddings import (
8
  load_text8, build_vocab, prepare_corpus, build_neg_table,
9
- train, normalize, most_similar, analogy
10
  )
11
 
12
  # --- Global state ---
@@ -17,49 +17,57 @@ def load_pretrained():
17
  """Load pre-trained embeddings if available."""
18
  try:
19
  W = np.load("pretrained_W.npy")
20
- meta = json.load(open("pretrained_vocab.json"))
 
21
  vocab = meta["vocab"]
22
  state["W"] = W
23
  state["W_norm"] = normalize(W)
24
  state["word2idx"] = {w: i for i, w in enumerate(vocab)}
25
  state["idx2word"] = {i: w for i, w in enumerate(vocab)}
26
  state["losses"] = meta.get("losses", [])
27
- return f"Loaded pre-trained: {W.shape[0]} words x {W.shape[1]} dims"
 
 
 
28
  except FileNotFoundError:
29
  return "No pre-trained embeddings found. Train from scratch!"
30
 
31
 
32
  # --- Tab 1: Train ---
33
  def run_training(embed_dim, window_size, learning_rate, num_neg, progress=gr.Progress()):
34
- progress(0, desc="Loading text8...")
35
- words = load_text8(500000)
36
- word2idx, idx2word, freqs = build_vocab(words)
37
- corpus = prepare_corpus(words, word2idx, freqs)
38
- neg_dist = build_neg_table(freqs)
 
 
39
 
40
- state["word2idx"] = word2idx
41
- state["idx2word"] = idx2word
42
- losses = []
43
 
44
- def callback(epoch, i, total, loss):
45
- pct = i / total
46
- progress(pct, desc=f"Epoch {epoch+1}: loss={loss:.4f}")
47
- losses.append(loss)
48
 
49
- W, _ = train(corpus, len(word2idx), neg_dist,
50
- epochs=3, embed_dim=int(embed_dim), lr=learning_rate,
51
- window=int(window_size), num_neg=int(num_neg), callback=callback)
52
 
53
- state["W"] = W
54
- state["W_norm"] = normalize(W)
55
- state["losses"] = losses
56
 
57
- fig = go.Figure()
58
- fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss",
59
- line=dict(color="#4F46E5")))
60
- fig.update_layout(title="Training Loss", xaxis_title="Step", yaxis_title="Loss",
61
- template="plotly_white")
62
- return fig, f"Done! {W.shape[0]} words x {W.shape[1]} dims"
 
 
63
 
64
 
65
  # --- Tab 2: Explore ---
@@ -153,6 +161,7 @@ def find_neighbors(word):
153
 
154
  # --- Build UI ---
155
  load_msg = load_pretrained()
 
156
 
157
  with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
158
  gr.Markdown(
@@ -163,10 +172,17 @@ with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
163
  "(https://kshreyas.dev/post/microembeddings/)"
164
  )
165
  gr.Markdown(f"*{load_msg}*")
 
 
 
 
166
 
167
  with gr.Tabs():
168
  with gr.Tab("Train"):
169
- gr.Markdown("Train word embeddings from scratch on text8 (cleaned Wikipedia).")
 
 
 
170
  with gr.Row():
171
  dim_slider = gr.Slider(25, 100, value=50, step=25,
172
  label="Embedding dimension")
 
6
  from sklearn.manifold import TSNE
7
  from microembeddings import (
8
  load_text8, build_vocab, prepare_corpus, build_neg_table,
9
+ train, normalize, most_similar, analogy, describe_text8_source
10
  )
11
 
12
  # --- Global state ---
 
17
  """Load pre-trained embeddings if available."""
18
  try:
19
  W = np.load("pretrained_W.npy")
20
+ with open("pretrained_vocab.json") as f:
21
+ meta = json.load(f)
22
  vocab = meta["vocab"]
23
  state["W"] = W
24
  state["W_norm"] = normalize(W)
25
  state["word2idx"] = {w: i for i, w in enumerate(vocab)}
26
  state["idx2word"] = {i: w for i, w in enumerate(vocab)}
27
  state["losses"] = meta.get("losses", [])
28
+ return (
29
+ "Loaded pre-trained full-text8 gensim vectors: "
30
+ f"{W.shape[0]} words x {W.shape[1]} dims"
31
+ )
32
  except FileNotFoundError:
33
  return "No pre-trained embeddings found. Train from scratch!"
34
 
35
 
36
  # --- Tab 1: Train ---
37
  def run_training(embed_dim, window_size, learning_rate, num_neg, progress=gr.Progress()):
38
+ fig = go.Figure()
39
+ try:
40
+ progress(0, desc="Loading text8...")
41
+ words = load_text8(500000)
42
+ word2idx, idx2word, freqs = build_vocab(words)
43
+ corpus = prepare_corpus(words, word2idx, freqs)
44
+ neg_dist = build_neg_table(freqs)
45
 
46
+ state["word2idx"] = word2idx
47
+ state["idx2word"] = idx2word
48
+ losses = []
49
 
50
+ def callback(epoch, i, total, loss):
51
+ pct = i / total
52
+ progress(pct, desc=f"Epoch {epoch+1}: loss={loss:.4f}")
53
+ losses.append(loss)
54
 
55
+ W, _ = train(corpus, len(word2idx), neg_dist,
56
+ epochs=3, embed_dim=int(embed_dim), lr=learning_rate,
57
+ window=int(window_size), num_neg=int(num_neg), callback=callback)
58
 
59
+ state["W"] = W
60
+ state["W_norm"] = normalize(W)
61
+ state["losses"] = losses
62
 
63
+ fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss",
64
+ line=dict(color="#4F46E5")))
65
+ fig.update_layout(title="Training Loss", xaxis_title="Step", yaxis_title="Loss",
66
+ template="plotly_white")
67
+ return fig, f"Done! {W.shape[0]} words x {W.shape[1]} dims"
68
+ except Exception as exc:
69
+ fig.update_layout(title="Training unavailable", template="plotly_white")
70
+ return fig, f"Training failed: {exc}"
71
 
72
 
73
  # --- Tab 2: Explore ---
 
161
 
162
  # --- Build UI ---
163
  load_msg = load_pretrained()
164
+ corpus_msg = describe_text8_source()
165
 
166
  with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
167
  gr.Markdown(
 
172
  "(https://kshreyas.dev/post/microembeddings/)"
173
  )
174
  gr.Markdown(f"*{load_msg}*")
175
+ gr.Markdown(
176
+ "*Preloaded vectors use gensim Word2Vec on the full 17M-word text8 corpus.* "
177
+ "*The Train tab reruns the NumPy implementation on a 500k-word subset so it stays interactive.*"
178
+ )
179
 
180
  with gr.Tabs():
181
  with gr.Tab("Train"):
182
+ gr.Markdown(
183
+ "Train word embeddings from scratch on text8 (cleaned Wikipedia).\n\n"
184
+ f"{corpus_msg}"
185
+ )
186
  with gr.Row():
187
  dim_slider = gr.Slider(25, 100, value=50, step=25,
188
  label="Embedding dimension")
microembeddings.py CHANGED
@@ -18,19 +18,41 @@ EPOCHS = 3
18
  MIN_COUNT = 5
19
  MAX_VOCAB = 10000
20
  SUBSAMPLE_THRESHOLD = 1e-4
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def load_text8(max_words=500000):
24
  """Download text8 and return list of words."""
25
- fname = "text8"
26
- if not os.path.exists(fname):
27
- url = "http://mattmahoney.net/dc/text8.zip"
28
- print("Downloading text8...")
29
- urllib.request.urlretrieve(url, "text8.zip")
30
- with zipfile.ZipFile("text8.zip") as z:
31
- z.extractall()
32
- os.remove("text8.zip")
33
- with open(fname) as f:
 
 
 
 
 
 
 
 
 
 
34
  words = f.read().split()[:max_words]
35
  print(f"Loaded {len(words)} words")
36
  return words
 
18
  MIN_COUNT = 5
19
  MAX_VOCAB = 10000
20
  SUBSAMPLE_THRESHOLD = 1e-4
21
+ TEXT8_FILE = "text8"
22
+ TEXT8_ZIP = "text8.zip"
23
+ TEXT8_URL = "http://mattmahoney.net/dc/text8.zip"
24
+
25
+
26
+ def describe_text8_source():
27
+ """Summarize how training data will be loaded."""
28
+ if os.path.exists(TEXT8_FILE):
29
+ return "Local text8 corpus found."
30
+ if os.path.exists(TEXT8_ZIP):
31
+ return "Local text8.zip found; it will be extracted on first train."
32
+ return "text8 is not bundled; Train will download it on first run."
33
 
34
 
35
  def load_text8(max_words=500000):
36
  """Download text8 and return list of words."""
37
+ downloaded = False
38
+ if not os.path.exists(TEXT8_FILE):
39
+ if not os.path.exists(TEXT8_ZIP):
40
+ print("Downloading text8...")
41
+ try:
42
+ urllib.request.urlretrieve(TEXT8_URL, TEXT8_ZIP)
43
+ except OSError as exc:
44
+ raise RuntimeError(
45
+ "Could not load text8. Add a local text8/text8.zip file or allow outbound download."
46
+ ) from exc
47
+ downloaded = True
48
+ try:
49
+ with zipfile.ZipFile(TEXT8_ZIP) as z:
50
+ z.extractall()
51
+ except (OSError, zipfile.BadZipFile) as exc:
52
+ raise RuntimeError("text8.zip is missing or invalid.") from exc
53
+ if downloaded:
54
+ os.remove(TEXT8_ZIP)
55
+ with open(TEXT8_FILE) as f:
56
  words = f.read().split()[:max_words]
57
  print(f"Loaded {len(words)} words")
58
  return words