|
|
import gradio as gr |
|
|
import re |
|
|
from functools import lru_cache |
|
|
import gensim.downloader as api |
|
|
from gensim.models import KeyedVectors |
|
|
import pandas as pd |
|
|
|
|
|
# Pre-trained embeddings available in the UI, keyed by their
# gensim.downloader name. Values are human-readable descriptions shown to
# users (dimensionality, corpus, rough size/speed trade-off).
# NOTE(review): the description strings contain garbled characters
# (mojibake, e.g. "β") — presumably em-dashes from a bad encoding pass;
# confirm intended text before "fixing".
MODEL_OPTIONS = {
    "glove-wiki-gigaword-50": "50d GloVe (Wikipedia+Gigaword) β small & fast",
    "glove-wiki-gigaword-100": "100d GloVe (Wikipedia+Gigaword) β balanced",
    "glove-wiki-gigaword-200": "200d GloVe (Wikipedia+Gigaword)",
    "glove-wiki-gigaword-300": "300d GloVe (Wikipedia+Gigaword)",
    "word2vec-google-news-300": "300d Google News Word2Vec β large (~1.6GB)"
}
|
|
|
|
|
# Tokenizer for word-arithmetic expressions: matches either a single '+' or
# '-' sign, or a maximal run of characters that are neither signs nor
# whitespace (i.e. a word). "a-b" therefore tokenizes as ['a', '-', 'b'].
TOKEN_RE = re.compile(r"[+\-]|[^+\-\s]+")
|
|
|
|
|
@lru_cache(maxsize=4)
def get_model(name: str) -> KeyedVectors:
    """Fetch the named pre-trained embedding via gensim's downloader.

    Memoized for up to 4 distinct model names, so repeated UI requests reuse
    the already-loaded KeyedVectors instead of re-downloading/re-parsing.
    """
    model = api.load(name)
    return model
|
|
|
|
|
def parse_expression(expr: str):
    """Split a '+'/'-' word expression into (positive, negative) word lists.

    Words accumulate into the positive list until a '-' sign appears; the
    most recently seen sign decides the destination of every following word.
    Returns two (possibly empty) lists of word strings.
    """
    positives, negatives = [], []
    bucket = positives  # destination for the next word; flipped by signs
    for token in re.findall(r"[+\-]|[^+\-\s]+", expr.strip()):
        token = token.strip()
        if token == '+':
            bucket = positives
        elif token == '-':
            bucket = negatives
        else:
            bucket.append(token)
    return positives, negatives
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_expression(model_name: str, expr: str, topn: int, exclude_inputs: bool):
    """Evaluate a '+'/'-' word expression against a model's vector space.

    Returns a (DataFrame | None, markdown-info string) pair for the Gradio
    outputs. Out-of-vocabulary words are skipped and reported rather than
    treated as errors; every failure path returns (None, message).
    """
    try:
        model = get_model(model_name)
    except Exception as e:
        return None, f"β Failed to load model '{model_name}': {e}"

    pos, neg = parse_expression(expr or "")
    if not (pos or neg):
        return None, "β οΈ Please enter at least one word."

    # Partition the parsed words by vocabulary membership.
    vocab = model.key_to_index
    pos_in = [w for w in pos if w in vocab]
    neg_in = [w for w in neg if w in vocab]
    oov = [w for w in pos + neg if w not in vocab]

    if not (pos_in or neg_in):
        return None, "β All words are out-of-vocabulary for this model. Try different words or a different model."

    # Over-fetch so that stripping the input words still leaves topn rows.
    fetch = topn + len(pos_in) + len(neg_in)
    try:
        results = model.most_similar(positive=pos_in, negative=neg_in, topn=fetch)
    except Exception as e:
        return None, f"β Computation error: {e}"

    if exclude_inputs:
        seen = {w.lower() for w in pos_in + neg_in}
        results = [pair for pair in results if pair[0].lower() not in seen]

    results = results[:topn]
    df = pd.DataFrame(results, columns=["Word", "Cosine similarity"]) if results else None

    summary = [
        f"**Model:** `{model_name}` (dim={model.vector_size})",
        f"**Positive:** {', '.join(pos_in) if pos_in else 'β'}",
        f"**Negative:** {', '.join(neg_in) if neg_in else 'β'}",
    ]
    if oov:
        summary.append(f"**Out-of-vocabulary skipped:** {', '.join(oov)}")
    return df, "\n\n".join(summary)
|
|
|
|
|
|
|
|
def compute_abc(model_name: str, a: str, b: str, c: str, topn: int, exclude_inputs: bool):
    """Compute the analogy vector A + B - C and list its nearest words.

    Parameters:
        model_name: key from MODEL_OPTIONS, resolved via get_model().
        a, b: words added to the query vector (blank entries are skipped).
        c: word subtracted from the query vector (blank entries are skipped).
        topn: number of neighbours to return.
        exclude_inputs: drop the input words themselves from the results.

    Returns a (DataFrame | None, markdown-info string) pair for the Gradio
    outputs; out-of-vocabulary words are skipped and reported in the info text.
    """
    try:
        model = get_model(model_name)
    except Exception as e:
        return None, f"β Failed to load model '{model_name}': {e}"

    used, missing = [], []
    vec = None
    for word, sign in [(a, +1), (b, +1), (c, -1)]:
        w = (word or '').strip()
        if not w:
            continue
        if w in model.key_to_index:
            used.append((w, sign))
            # BUGFIX: scale by the sign *before* accumulating. Previously the
            # first in-vocabulary word was taken as-is (`vec = v`), so a
            # leading negative word (e.g. only C filled in, or A/B blank/OOV)
            # was added instead of subtracted.
            contrib = sign * model.get_vector(w)
            vec = contrib if vec is None else vec + contrib
        else:
            missing.append(w)

    if vec is None:
        return None, "β No valid words to compute a vector."

    # Over-fetch so filtering out the input words still leaves topn rows.
    results = model.similar_by_vector(vec, topn=topn + len(used))
    if exclude_inputs:
        inputs = {w for w, _ in used}
        results = [(w, s) for (w, s) in results if w not in inputs]
    results = results[:topn]

    df = pd.DataFrame(results, columns=["Word", "Cosine similarity"]) if results else None

    info_bits = [
        f"**Model:** `{model_name}` (dim={model.vector_size})",
        f"**Used:** {', '.join([('+' if s>0 else 'β') + w for w,s in used]) if used else 'β'}",
    ]
    if missing:
        info_bits.append(f"**Out-of-vocabulary skipped:** {', '.join(missing)}")
    info = "\n\n".join(info_bits)
    return df, info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two tabs (free-form expression / fixed A + B - C analogy) that
# share one row of controls (model choice, result count, input filtering).
# ---------------------------------------------------------------------------
with gr.Blocks(title="Word Embeddings Playground β Gradio") as demo:
    gr.Markdown("""
# π§ Word Embeddings Playground
Type equations like `king + woman - man` and explore nearest words using pre-trained Gensim embeddings.
""")

    # Shared controls, read by both tabs' click handlers.
    with gr.Row():
        model_name = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="glove-wiki-gigaword-100",
            label="Model",
            info="If downloads stall, try a smaller model first (50d/100d)."
        )
        topn = gr.Slider(5, 50, value=10, step=1, label="Top N similar results")
        exclude_inputs = gr.Checkbox(value=True, label="Exclude input words from results")

    # Tab 1: free-form "+"/"-" expression, parsed by parse_expression().
    with gr.Tab("Expression: A + B β C + β¦"):
        expr = gr.Textbox(value="king + woman - man", label="Expression (use + and -)")
        compute_btn = gr.Button("Compute", variant="primary")
        out_df = gr.Dataframe(headers=["Word", "Cosine similarity"], interactive=False)
        out_info = gr.Markdown()

        # Clickable example expressions that populate the textbox.
        examples = gr.Examples(
            examples=[
                ["king + woman - man"],
                ["paris - france + italy"],
                ["walk + past - present"],
                ["big - bigger + small"],
                ["programmer + woman - man"],
            ],
            inputs=[expr],
            label="Examples"
        )

        compute_btn.click(
            fn=compute_expression,
            inputs=[model_name, expr, topn, exclude_inputs],
            outputs=[out_df, out_info]
        )

    # Tab 2: fixed three-word analogy A + B - C, handled by compute_abc().
    with gr.Tab("Advanced: A + B β C"):
        with gr.Row():
            a = gr.Textbox(value="king", label="Word A (+)")
            b = gr.Textbox(value="woman", label="Word B (+)")
            c = gr.Textbox(value="man", label="Word C (β)")
        compute_btn2 = gr.Button("Compute A + B β C")
        out_df2 = gr.Dataframe(headers=["Word", "Cosine similarity"], interactive=False)
        out_info2 = gr.Markdown()

        compute_btn2.click(
            fn=compute_abc,
            inputs=[model_name, a, b, c, topn, exclude_inputs],
            outputs=[out_df2, out_info2]
        )

    gr.Markdown("Built with **Gradio** + **Gensim**. Models load via `gensim.downloader`; first-time downloads can take a while depending on size.")


if __name__ == "__main__":
    # Start the local Gradio server when run as a script (no share/auth args,
    # so it binds to the default localhost address).
    demo.launch()
|
|
|