Ram-N committed on
Commit
6f37f82
·
1 Parent(s): 8e7213f

Switch visualization from matplotlib to plotly

Browse files
Files changed (2) hide show
  1. app.py +93 -63
  2. requirements.txt +1 -1
app.py CHANGED
@@ -5,8 +5,8 @@ from sklearn.decomposition import PCA
5
  from sklearn.manifold import TSNE
6
  from sklearn.cluster import KMeans
7
  from sklearn.metrics import silhouette_score
8
- import matplotlib.pyplot as plt
9
- import matplotlib.cm as cm
10
  import gensim.downloader as api
11
 
12
  # Load both models at startup
@@ -103,7 +103,6 @@ WORD_PRESETS = {
103
  "doctor": "doctor",
104
  "dharma": "dharma",
105
  "cricket": "cricket",
106
- "Custom": None,
107
  }
108
 
109
  ANALOGY_PRESETS = {
@@ -111,7 +110,6 @@ ANALOGY_PRESETS = {
111
  "Gender (man : king :: woman : ?)": ("man", "king", "woman"),
112
  "Institutions (school : teacher :: hospital : ?)": ("school", "teacher","hospital"),
113
  "Nature (day : sun :: night : ?)": ("day", "sun", "night"),
114
- "Custom": None,
115
  }
116
 
117
  SENTENCE_PRESETS = {
@@ -119,7 +117,6 @@ SENTENCE_PRESETS = {
119
  "Negation — tricky for GloVe": ("I am happy", "I am not happy"),
120
  "Same word, different meaning": ("The bat flew at night", "He swung the bat"),
121
  "Unrelated": ("The train is very fast", "My grandmother makes great chai"),
122
- "Custom": None,
123
  }
124
 
125
 
@@ -170,28 +167,12 @@ def _auto_cluster(vecs):
170
  def _assign_colors(groups):
171
  """Map unique group names to distinct colours."""
172
  unique = list(dict.fromkeys(groups)) # preserve order, deduplicate
173
- palette = cm.get_cmap("tab10", len(unique))
174
- color_map = {g: palette(i) for i, g in enumerate(unique)}
175
- return [color_map[g] for g in groups], color_map
176
-
177
-
178
- def _scatter(ax, coords, words, colors, title):
179
- ax.scatter(coords[:, 0], coords[:, 1], c=colors, s=80, zorder=2)
180
- for i, word in enumerate(words):
181
- ax.annotate(word, (coords[i, 0], coords[i, 1]),
182
- textcoords="offset points", xytext=(5, 5), fontsize=9)
183
- ax.set_title(title)
184
- ax.axhline(0, color="lightgrey", linewidth=0.5)
185
- ax.axvline(0, color="lightgrey", linewidth=0.5)
186
-
187
-
188
- def _make_legend(ax, color_map):
189
- handles = [
190
- plt.Line2D([0], [0], marker="o", color="w",
191
- markerfacecolor=color, markersize=8, label=group)
192
- for group, color in color_map.items()
193
  ]
194
- ax.legend(handles=handles, loc="best", fontsize=8)
 
195
 
196
 
197
  def visualize(words_text, model_choice, selected_set):
@@ -213,9 +194,11 @@ def visualize(words_text, model_choice, selected_set):
213
  skipped = [w for w in words if w not in glove]
214
 
215
  if len(valid) < 2:
216
- fig, ax = plt.subplots()
217
- ax.text(0.5, 0.5, "Not enough words found in GloVe vocabulary.\nTry switching to Sentence Transformers.",
218
- ha="center", va="center", transform=ax.transAxes)
 
 
219
  return fig
220
 
221
  words = [v[0] for v in valid]
@@ -233,21 +216,48 @@ def visualize(words_text, model_choice, selected_set):
233
  # Colours — auto-cluster custom words; use predefined groups for presets
234
  if groups is None:
235
  groups = _auto_cluster(np.array(vecs))
236
- colors, color_map = _assign_colors(groups)
237
-
238
- # Plot
239
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
240
- _scatter(ax1, pca_2d, words, colors, title=f"PCA ({model_choice})")
241
- _scatter(ax2, tsne_2d, words, colors, title=f"t-SNE ({model_choice})")
242
-
243
- if color_map:
244
- _make_legend(ax2, color_map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  if skipped:
247
- fig.text(0.5, 0.01, f"Skipped (not in GloVe): {', '.join(skipped)}",
248
- ha="center", fontsize=8, color="grey")
 
 
 
249
 
250
- fig.tight_layout()
251
  return fig
252
 
253
 
@@ -271,18 +281,25 @@ with gr.Blocks(title="Embedding Playground") as demo:
271
  # --- TAB 1 ---
272
  with gr.Tab("Word Explorer"):
273
  gr.Markdown("Enter a word to see its vector and closest neighbours.")
 
274
  word_preset = gr.Dropdown(
275
  choices=list(WORD_PRESETS.keys()), value="tiger",
276
- label="Try a preset word, or choose Custom to type your own"
277
  )
278
- word_input = gr.Textbox(label="Enter a word", value="tiger")
279
  vec_output = gr.Textbox(label="Vector (first 10 dims)")
280
  neighbors_output = gr.Textbox(label="Closest words")
281
  btn = gr.Button("Explore", variant="primary")
282
 
283
  def fill_word(preset):
284
- return WORD_PRESETS[preset] if preset != "Custom" else gr.update()
 
 
 
 
 
285
 
 
286
  word_preset.change(fill_word, inputs=word_preset, outputs=word_input)
287
  btn.click(word_explorer, inputs=[word_input, model_choice],
288
  outputs=[vec_output, neighbors_output])
@@ -316,25 +333,30 @@ end up near each other in vector space.
316
  with gr.Tab("Analogies"):
317
  gr.Markdown("### A is to B as C is to ?")
318
  gr.Markdown("Vector arithmetic: **B − A + C** → find the closest word")
 
319
  analogy_preset = gr.Dropdown(
320
  choices=list(ANALOGY_PRESETS.keys()),
321
  value=list(ANALOGY_PRESETS.keys())[0],
322
- label="Try a preset, or choose Custom to type your own"
323
  )
324
  _default_a, _default_b, _default_c = ANALOGY_PRESETS[list(ANALOGY_PRESETS.keys())[0]]
325
- a_in = gr.Textbox(label="A (starting point)", value=_default_a)
326
- b_in = gr.Textbox(label="B (related to A)", value=_default_b)
327
- c_in = gr.Textbox(label="C (new starting point)", value=_default_c)
328
  analogy_vec = gr.Textbox(label="Result vector (first 10 dims)")
329
  analogy_out = gr.Textbox(label="Closest words")
330
  btn2 = gr.Button("Solve Analogy", variant="primary")
331
 
332
  def fill_analogy(preset):
333
- if preset == "Custom":
334
- return gr.update(), gr.update(), gr.update()
335
  a, b, c = ANALOGY_PRESETS[preset]
336
  return a, b, c
337
 
 
 
 
 
 
 
338
  analogy_preset.change(fill_analogy, inputs=analogy_preset, outputs=[a_in, b_in, c_in])
339
  btn2.click(analogy, inputs=[a_in, b_in, c_in, model_choice],
340
  outputs=[analogy_vec, analogy_out])
@@ -342,23 +364,28 @@ end up near each other in vector space.
342
  # --- TAB 3 ---
343
  with gr.Tab("Sentence Similarity"):
344
  gr.Markdown("Compare two sentences. Score ranges from 0 (unrelated) to 1 (identical meaning).")
 
345
  sent_preset = gr.Dropdown(
346
  choices=list(SENTENCE_PRESETS.keys()),
347
  value="Similar meaning",
348
- label="Try a preset pair, or choose Custom to type your own"
349
  )
350
  _default_s1, _default_s2 = SENTENCE_PRESETS["Similar meaning"]
351
- s1 = gr.Textbox(label="Sentence 1", value=_default_s1)
352
- s2 = gr.Textbox(label="Sentence 2", value=_default_s2)
353
  sim_output = gr.Textbox(label="Similarity")
354
  btn3 = gr.Button("Compare", variant="primary")
355
 
356
  def fill_sentences(preset):
357
- if preset == "Custom":
358
- return gr.update(), gr.update()
359
  s1v, s2v = SENTENCE_PRESETS[preset]
360
  return s1v, s2v
361
 
 
 
 
 
 
 
362
  sent_preset.change(fill_sentences, inputs=sent_preset, outputs=[s1, s2])
363
  btn3.click(sentence_similarity, inputs=[s1, s2, model_choice],
364
  outputs=sim_output)
@@ -370,23 +397,26 @@ end up near each other in vector space.
370
  "**PCA**: distances between clusters are meaningful. "
371
  "**t-SNE**: clusters are visually clearer, but distances *between* clusters are not meaningful."
372
  )
 
373
  set_dropdown = gr.Dropdown(
374
  choices=list(WORD_SETS.keys()),
375
  value="Semantic clusters",
376
- label="Word set — select one option from the dropdown, then press Plot"
377
  )
378
  custom_words = gr.Textbox(
379
- label="Custom words (comma separated — only used when 'Custom' is selected above)",
380
- placeholder="e.g. moon, star, sun, cloud, rain"
 
381
  )
382
  btn4 = gr.Button("Plot", variant="primary")
383
  plot_output = gr.Plot()
384
 
385
- set_dropdown.change(
386
- fn=lambda s: "" if s != "Custom" else gr.update(),
387
- inputs=set_dropdown,
388
- outputs=custom_words
389
- )
 
390
 
391
  btn4.click(visualize, inputs=[custom_words, model_choice, set_dropdown],
392
  outputs=plot_output)
 
5
  from sklearn.manifold import TSNE
6
  from sklearn.cluster import KMeans
7
  from sklearn.metrics import silhouette_score
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
  import gensim.downloader as api
11
 
12
  # Load both models at startup
 
103
  "doctor": "doctor",
104
  "dharma": "dharma",
105
  "cricket": "cricket",
 
106
  }
107
 
108
  ANALOGY_PRESETS = {
 
110
  "Gender (man : king :: woman : ?)": ("man", "king", "woman"),
111
  "Institutions (school : teacher :: hospital : ?)": ("school", "teacher","hospital"),
112
  "Nature (day : sun :: night : ?)": ("day", "sun", "night"),
 
113
  }
114
 
115
  SENTENCE_PRESETS = {
 
117
  "Negation — tricky for GloVe": ("I am happy", "I am not happy"),
118
  "Same word, different meaning": ("The bat flew at night", "He swung the bat"),
119
  "Unrelated": ("The train is very fast", "My grandmother makes great chai"),
 
120
  }
121
 
122
 
 
167
  def _assign_colors(groups):
168
  """Map unique group names to distinct colours."""
169
  unique = list(dict.fromkeys(groups)) # preserve order, deduplicate
170
+ palette = [
171
+ "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
172
+ "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  ]
174
+ color_map = {g: palette[i % len(palette)] for i, g in enumerate(unique)}
175
+ return [color_map[g] for g in groups], color_map
176
 
177
 
178
  def visualize(words_text, model_choice, selected_set):
 
194
  skipped = [w for w in words if w not in glove]
195
 
196
  if len(valid) < 2:
197
+ fig = go.Figure()
198
+ fig.add_annotation(
199
+ text="Not enough words found in GloVe vocabulary.<br>Try switching to Sentence Transformers.",
200
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14)
201
+ )
202
  return fig
203
 
204
  words = [v[0] for v in valid]
 
216
  # Colours — auto-cluster custom words; use predefined groups for presets
217
  if groups is None:
218
  groups = _auto_cluster(np.array(vecs))
219
+ _, color_map = _assign_colors(groups)
220
+
221
+ # Build Plotly subplots
222
+ fig = make_subplots(rows=1, cols=2,
223
+ subplot_titles=[f"PCA ({model_choice})", f"t-SNE ({model_choice})"])
224
+
225
+ unique_groups = list(dict.fromkeys(groups))
226
+ for grp in unique_groups:
227
+ indices = [j for j, g in enumerate(groups) if g == grp]
228
+ color = color_map[grp]
229
+
230
+ fig.add_trace(go.Scatter(
231
+ x=pca_2d[indices, 0], y=pca_2d[indices, 1],
232
+ mode="markers+text",
233
+ text=[words[j] for j in indices],
234
+ textposition="top center",
235
+ marker=dict(color=color, size=9),
236
+ name=grp,
237
+ legendgroup=grp,
238
+ showlegend=True,
239
+ ), row=1, col=1)
240
+
241
+ fig.add_trace(go.Scatter(
242
+ x=tsne_2d[indices, 0], y=tsne_2d[indices, 1],
243
+ mode="markers+text",
244
+ text=[words[j] for j in indices],
245
+ textposition="top center",
246
+ marker=dict(color=color, size=9),
247
+ name=grp,
248
+ legendgroup=grp,
249
+ showlegend=False,
250
+ ), row=1, col=2)
251
+
252
+ fig.update_layout(height=520, margin=dict(t=60, b=40))
253
 
254
  if skipped:
255
+ fig.add_annotation(
256
+ text=f"Skipped (not in GloVe): {', '.join(skipped)}",
257
+ xref="paper", yref="paper", x=0.5, y=-0.05,
258
+ showarrow=False, font=dict(size=10, color="grey")
259
+ )
260
 
 
261
  return fig
262
 
263
 
 
281
  # --- TAB 1 ---
282
  with gr.Tab("Word Explorer"):
283
  gr.Markdown("Enter a word to see its vector and closest neighbours.")
284
+ word_mode = gr.Radio(["Pre-set", "Custom"], value="Pre-set", label="Input mode")
285
  word_preset = gr.Dropdown(
286
  choices=list(WORD_PRESETS.keys()), value="tiger",
287
+ label="Preset word"
288
  )
289
+ word_input = gr.Textbox(label="Custom word", value="tiger", interactive=False)
290
  vec_output = gr.Textbox(label="Vector (first 10 dims)")
291
  neighbors_output = gr.Textbox(label="Closest words")
292
  btn = gr.Button("Explore", variant="primary")
293
 
294
  def fill_word(preset):
295
+ return WORD_PRESETS[preset]
296
+
297
+ def toggle_word_mode(mode):
298
+ if mode == "Pre-set":
299
+ return gr.update(interactive=True), gr.update(interactive=False)
300
+ return gr.update(interactive=False), gr.update(interactive=True)
301
 
302
+ word_mode.change(toggle_word_mode, inputs=word_mode, outputs=[word_preset, word_input])
303
  word_preset.change(fill_word, inputs=word_preset, outputs=word_input)
304
  btn.click(word_explorer, inputs=[word_input, model_choice],
305
  outputs=[vec_output, neighbors_output])
 
333
  with gr.Tab("Analogies"):
334
  gr.Markdown("### A is to B as C is to ?")
335
  gr.Markdown("Vector arithmetic: **B − A + C** → find the closest word")
336
+ analogy_mode = gr.Radio(["Pre-set", "Custom"], value="Pre-set", label="Input mode")
337
  analogy_preset = gr.Dropdown(
338
  choices=list(ANALOGY_PRESETS.keys()),
339
  value=list(ANALOGY_PRESETS.keys())[0],
340
+ label="Preset analogy"
341
  )
342
  _default_a, _default_b, _default_c = ANALOGY_PRESETS[list(ANALOGY_PRESETS.keys())[0]]
343
+ a_in = gr.Textbox(label="A (starting point)", value=_default_a, interactive=False)
344
+ b_in = gr.Textbox(label="B (related to A)", value=_default_b, interactive=False)
345
+ c_in = gr.Textbox(label="C (new starting point)", value=_default_c, interactive=False)
346
  analogy_vec = gr.Textbox(label="Result vector (first 10 dims)")
347
  analogy_out = gr.Textbox(label="Closest words")
348
  btn2 = gr.Button("Solve Analogy", variant="primary")
349
 
350
  def fill_analogy(preset):
 
 
351
  a, b, c = ANALOGY_PRESETS[preset]
352
  return a, b, c
353
 
354
+ def toggle_analogy_mode(mode):
355
+ if mode == "Pre-set":
356
+ return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
357
+ return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
358
+
359
+ analogy_mode.change(toggle_analogy_mode, inputs=analogy_mode, outputs=[analogy_preset, a_in, b_in, c_in])
360
  analogy_preset.change(fill_analogy, inputs=analogy_preset, outputs=[a_in, b_in, c_in])
361
  btn2.click(analogy, inputs=[a_in, b_in, c_in, model_choice],
362
  outputs=[analogy_vec, analogy_out])
 
364
  # --- TAB 3 ---
365
  with gr.Tab("Sentence Similarity"):
366
  gr.Markdown("Compare two sentences. Score ranges from 0 (unrelated) to 1 (identical meaning).")
367
+ sent_mode = gr.Radio(["Pre-set", "Custom"], value="Pre-set", label="Input mode")
368
  sent_preset = gr.Dropdown(
369
  choices=list(SENTENCE_PRESETS.keys()),
370
  value="Similar meaning",
371
+ label="Preset sentence pair"
372
  )
373
  _default_s1, _default_s2 = SENTENCE_PRESETS["Similar meaning"]
374
+ s1 = gr.Textbox(label="Sentence 1", value=_default_s1, interactive=False)
375
+ s2 = gr.Textbox(label="Sentence 2", value=_default_s2, interactive=False)
376
  sim_output = gr.Textbox(label="Similarity")
377
  btn3 = gr.Button("Compare", variant="primary")
378
 
379
  def fill_sentences(preset):
 
 
380
  s1v, s2v = SENTENCE_PRESETS[preset]
381
  return s1v, s2v
382
 
383
+ def toggle_sent_mode(mode):
384
+ if mode == "Pre-set":
385
+ return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)
386
+ return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=True)
387
+
388
+ sent_mode.change(toggle_sent_mode, inputs=sent_mode, outputs=[sent_preset, s1, s2])
389
  sent_preset.change(fill_sentences, inputs=sent_preset, outputs=[s1, s2])
390
  btn3.click(sentence_similarity, inputs=[s1, s2, model_choice],
391
  outputs=sim_output)
 
397
  "**PCA**: distances between clusters are meaningful. "
398
  "**t-SNE**: clusters are visually clearer, but distances *between* clusters are not meaningful."
399
  )
400
+ viz_mode = gr.Radio(["Pre-set", "Custom"], value="Pre-set", label="Input mode")
401
  set_dropdown = gr.Dropdown(
402
  choices=list(WORD_SETS.keys()),
403
  value="Semantic clusters",
404
+ label="Word set"
405
  )
406
  custom_words = gr.Textbox(
407
+ label="Custom words (comma separated)",
408
+ placeholder="e.g. moon, star, sun, cloud, rain",
409
+ interactive=False
410
  )
411
  btn4 = gr.Button("Plot", variant="primary")
412
  plot_output = gr.Plot()
413
 
414
+ def toggle_viz_mode(mode):
415
+ if mode == "Pre-set":
416
+ return gr.update(interactive=True), gr.update(interactive=False)
417
+ return gr.update(interactive=False, value="Custom"), gr.update(interactive=True)
418
+
419
+ viz_mode.change(toggle_viz_mode, inputs=viz_mode, outputs=[set_dropdown, custom_words])
420
 
421
  btn4.click(visualize, inputs=[custom_words, model_choice, set_dropdown],
422
  outputs=plot_output)
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio
2
  sentence-transformers
3
  scikit-learn
4
- matplotlib
5
  gensim
 
1
  gradio
2
  sentence-transformers
3
  scikit-learn
4
+ plotly
5
  gensim