Spaces:
Running
Running
| # app.py | |
| import random | |
| import numpy as np | |
| import streamlit as st | |
| import plotly.graph_objects as go | |
| from sklearn.decomposition import PCA | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
# Page-level settings: browser tab title and full-width layout. Called before
# any other Streamlit element is created (older Streamlit versions require it).
st.set_page_config(layout="wide", page_title="Embedding Visualizer")
# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
# Category name -> example items. The keys are shown verbatim in the dataset
# picker. The first six categories hold 25 items each; the last three are
# two-part mixes of 20 items (grouped by the inline comments).
BASE_SETS = {
    "countries": [
        "Germany", "France", "Italy", "Spain", "Portugal",
        "Poland", "Netherlands", "Belgium", "Austria", "Switzerland",
        "Greece", "Norway", "Sweden", "Finland", "Denmark",
        "Ireland", "Hungary", "Czechia", "Slovakia", "Slovenia",
        "Romania", "Bulgaria", "Croatia", "Estonia", "Latvia",
    ],
    "animals": [
        "cat", "dog", "lion", "tiger", "bear",
        "wolf", "fox", "eagle", "shark", "whale",
        "zebra", "giraffe", "elephant", "hippopotamus", "rhinoceros",
        "kangaroo", "panda", "otter", "seal", "dolphin",
        "chimpanzee", "gorilla", "leopard", "cheetah", "lynx",
    ],
    "furniture": [
        "armchair", "sofa", "dining table", "coffee table", "bookshelf",
        "bed", "wardrobe", "desk", "office chair", "dresser",
        "nightstand", "side table", "tv stand", "loveseat", "chaise lounge",
        "bench", "hutch", "kitchen island", "futon", "recliner",
        "ottoman", "console table", "vanity", "buffet", "sectional sofa",
    ],
    "actors": [
        "Brad Pitt", "Angelina Jolie", "Meryl Streep", "Leonardo DiCaprio", "Tom Hanks",
        "Scarlett Johansson", "Robert De Niro", "Natalie Portman", "Matt Damon", "Cate Blanchett",
        "Johnny Depp", "Keanu Reeves", "Hugh Jackman", "Emma Stone", "Ryan Gosling",
        "Jennifer Lawrence", "Christian Bale", "Charlize Theron", "Will Smith", "Anne Hathaway",
        "Denzel Washington", "Morgan Freeman", "Julia Roberts", "George Clooney", "Kate Winslet",
    ],
    "rock groups": [
        "The Beatles", "Rolling Stones", "Pink Floyd", "Queen", "Led Zeppelin",
        "U2", "AC/DC", "Nirvana", "Radiohead", "Metallica",
        "Guns N' Roses", "Red Hot Chili Peppers", "Coldplay", "Pearl Jam", "The Police",
        "Aerosmith", "Green Day", "Foo Fighters", "The Doors", "Bon Jovi",
        "Deep Purple", "The Who", "The Kinks", "Fleetwood Mac", "The Beach Boys",
    ],
    "sports": [
        "soccer", "basketball", "tennis", "baseball", "golf",
        "swimming", "cycling", "running", "volleyball", "rugby",
        "boxing", "skiing", "snowboarding", "surfing", "skateboarding",
        "karate", "judo", "fencing", "rowing", "badminton",
        "cricket", "table tennis", "gymnastics", "hockey", "climbing",
    ],
    "ai_cs_concepts": [
        # AI concepts (10)
        "neural network", "transformer", "embedding", "fine-tuning", "vector database",
        "retrieval-augmented generation", "prompt", "agent", "inference", "self-attention",
        # Classical CS concepts (10)
        "algorithm", "data structure", "compiler", "register", "stack",
        "queue", "binary tree", "hash table", "database", "quicksort",
    ],
    "tech_companies": [
        # AI / Tech (7)
        "OpenAI", "Anthropic", "Google", "Microsoft", "Meta", "NVIDIA", "Hugging Face",
        # Automotive (7)
        "Tesla", "BMW", "Mercedes-Benz", "Volkswagen", "Toyota", "Ford", "Volvo",
        # Pharma / Life Science (6)
        "Pfizer", "Roche", "Novartis", "Johnson & Johnson", "Bayer", "BioNTech",
    ],
    "finance": [
        # Core finance terms (10)
        "equity", "bond", "derivative", "liquidity", "leverage",
        "portfolio", "valuation", "capital", "revenue", "profit",
        # Currencies (10)
        "US dollar", "euro", "British pound", "Japanese yen", "Swiss franc",
        "Chinese yuan", "Canadian dollar", "Australian dollar", "Indian rupee", "Brazilian real",
    ],
}
# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(
    base: dict,
    n: int = 3,
    *,
    sources_per_set: int = 3,
    items_per_source: int = 7,
    rng=None,
) -> dict:
    """Build up to ``n`` datasets mixing random items from random base categories.

    Each mixed set samples ``sources_per_set`` distinct category names from
    ``base`` and takes up to ``items_per_source`` random items from each of
    them (in the sampled order). The set is keyed by its category names
    joined with ``"/"``, e.g. ``"animals/sports/finance"``.

    Args:
        base: Mapping of category name -> list of items.
        n: Number of mixed sets requested.
        sources_per_set: Categories combined per set; clamped to ``len(base)``.
        items_per_source: Maximum number of items sampled from each category.
        rng: Object with a ``sample`` method (e.g. ``random.Random(seed)``)
            for reproducible draws; defaults to the ``random`` module.

    Returns:
        Dict of mixed-set name -> list of items. Every set uses a distinct
        combination of categories; fewer than ``n`` sets are returned only
        when ``base`` does not have ``n`` distinct combinations.
    """
    from math import comb

    rng = random if rng is None else rng
    keys = list(base)
    k = min(sources_per_set, len(keys))
    if n <= 0 or k <= 0:
        return {}

    # Re-drawing the same categories would either overwrite an earlier set
    # (same order -> same key, silently losing a set) or produce a near
    # duplicate (different order), so keep drawing until the combinations
    # are distinct. `target` bounds the loop by the number of combinations.
    target = min(n, comb(len(keys), k))
    seen = set()
    out = {}
    while len(out) < target:
        src = rng.sample(keys, k)
        combo = frozenset(src)
        if combo in seen:
            continue
        seen.add(combo)
        items = []
        for s in src:
            items.extend(rng.sample(base[s], min(items_per_source, len(base[s]))))
        out["/".join(src)] = items
    return out
# Build the dataset catalogue (base sets first, then the random mixes) only on
# the first run of a browser session; later reruns reuse the stored dict, so the
# random mixes stay stable while the user interacts with the page.
if "datasets" not in st.session_state:
    catalogue = dict(BASE_SETS)
    catalogue.update(make_random_mixed_sets(BASE_SETS, 3))
    st.session_state.datasets = catalogue
DATASETS = st.session_state.datasets  # shorthand
# -----------------------------
# Models (transformers)
# -----------------------------
# Display label shown in the picker -> Hugging Face Hub repo id. The number in
# each label is the embedding width advertised for that checkpoint.
MODELS = dict(
    [
        ("all-MiniLM-L6-v2 (384d)", "sentence-transformers/all-MiniLM-L6-v2"),
        ("all-mpnet-base-v2 (768d)", "sentence-transformers/all-mpnet-base-v2"),
        ("all-roberta-large-v1 (1024d)", "sentence-transformers/all-roberta-large-v1"),
    ]
)
@st.cache_resource
def load_model(model_name: str):
    """Load the tokenizer and encoder for a Hugging Face Hub model id.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole script
    on every widget interaction, and without the cache each rerun reloaded the
    tokenizer and model weights from disk. One shared copy is kept per
    ``model_name``.

    Args:
        model_name: Hub repo id, e.g. one of the values of ``MODELS``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode
        (disables dropout for deterministic inference).
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModel.from_pretrained(model_name)
    mdl.eval()
    return tok, mdl
@st.cache_data(max_entries=128)
def embed_texts(model_name: str, texts_tuple: tuple):
    """Embed each text with a sentence-transformers encoder (mean pooling + L2 norm).

    Cached with ``st.cache_data`` (bounded to 128 entries because users can
    type arbitrary word lists): switching projection mode or re-running with
    the same words no longer recomputes the forward pass. ``texts_tuple`` is a
    tuple precisely so it is hashable for this cache.

    Args:
        model_name: Hub repo id passed to ``load_model``.
        texts_tuple: Texts to embed, one embedding per element.

    Returns:
        NumPy array of shape ``(len(texts_tuple), hidden_size)``; each row is
        the L2-normalised mean of the token embeddings of one text.
    """
    tokenizer, model = load_model(model_name)
    texts = list(texts_tuple)
    if not texts:
        # Avoid running the model on an empty batch.
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state
        # Mean pooling: average token vectors, ignoring padding positions.
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        embeddings = summed / counts
        # These sentence-transformers checkpoints L2-normalise after pooling
        # (see their model cards); without it, vector length leaks into the
        # PCA projection instead of the cosine geometry the models learned.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()
# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
def goto(page: str):
    """Navigate to ``page`` by setting the ``page`` query parameter and rerunning.

    Streamlit ends the current script run at ``st.rerun()``, so code after a
    ``goto(...)`` call is not reached.
    """
    params = st.query_params
    params["page"] = page
    st.rerun()
# Markdown for the "info" page. Paragraphs, lists and horizontal rules are
# separated by blank lines on purpose: without them Markdown merges paragraphs,
# and a line directly followed by "---" becomes a setext <h2> heading (the
# "Example:" paragraph used to render as a heading).
_INFO_MARKDOWN = """
# 🧠 Embedding Visualizer – About

This demo shows how **vector embeddings** can capture the meaning of words and place them in a **numerical space** where related items appear close together.

You can:

- Choose from predefined or mixed datasets (e.g., countries, animals, actors, sports)
- Select different embedding models to compare results
- Switch between 2D and 3D visualizations
- Edit the list of words directly and see the updated projection instantly

---

## 📌 What are Vector Embeddings?

A **vector embedding** is a way of representing text (words, sentences, or documents) as a list of numbers — a point in a high-dimensional space.
These numbers are produced by a trained **language model** that captures semantic meaning.

In this space:

- Words with **similar meanings** end up **near each other**
- Dissimilar words are placed **far apart**
- The model can detect relationships and groupings that aren’t obvious from spelling or grammar alone

Example:
`"cat"` and `"dog"` will likely be closer to each other than to `"table"`, because the model “knows” they are both animals.

---

## 🔍 How the Demo Works

1. **Embedding step** – Each word is converted into a high-dimensional vector (e.g., 384, 768, or 1024 dimensions depending on the model).
2. **Dimensionality reduction** – Since humans can’t visualize hundreds of dimensions, the vectors are projected to 2D or 3D using **PCA** (Principal Component Analysis).
3. **Visualization** – The projected points are plotted, with labels showing the original words.

You can rotate the 3D view to explore groupings.

---

## 💡 Typical Applications of Embeddings

- **Semantic search** – Find relevant results even if exact keywords don’t match
- **Clustering & topic discovery** – Group related items automatically
- **Recommendations** – Suggest similar products, movies, or articles
- **Deduplication** – Detect near-duplicate content
- **Analogies** – Explore relationships like *"king" – "man" + "woman" ≈ "queen"*

---

## 🚀 Try it Yourself

- Pick a dataset or create your own by editing the list
- Switch models to compare how the embedding space changes
- Toggle between 2D and 3D to explore patterns
"""

# Current page from the URL (?page=info); anything else shows the demo.
page = st.query_params.get("page", "demo")
if page == "info":
    st.write(_INFO_MARKDOWN.strip())
    if st.button("⬅ back to demo"):
        goto("demo")
    # Render only the info page: stop before the demo UI below is built.
    st.stop()
# -----------------------------
# Top compact bar
# -----------------------------
# Widget pattern used here: seed the widget's Session-State key *before* the
# widget is created, give the widget that `key`, and do NOT pass `index=`.
# Passing an index derived from the previous choice changes the widget's
# identity after each selection, so the next user selection was dropped
# (needed two clicks). Combining `index=` with a pre-seeded key also triggers
# Streamlit's "default value ... set via the Session State API" warning.
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])
with c1:
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = "actors" if "actors" in DATASETS else list(DATASETS.keys())[0]
    dataset_name = st.selectbox("dataset", list(DATASETS.keys()), key="dataset_name")
with c2:
    labels = list(MODELS.keys())
    if "model_label" not in st.session_state:
        # Default to the second model (all-mpnet-base-v2).
        st.session_state.model_label = labels[1]
    chosen_label = st.selectbox("embedding model", labels, key="model_label")
    # Downstream code reads the Hub repo id from Session State.
    st.session_state.model_name = MODELS[chosen_label]
with c3:
    # Pass `index` only while "proj_mode" is not yet in Session State, so the
    # 3D default applies on the first render and the stored state drives the
    # widget afterwards.
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # 3D default
    st.radio("projection", **radio_kwargs)
with c4:
    if st.button("ℹ info"):
        goto("info")
# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")

# Keep the textarea in sync with the dataset selection: its text is replaced
# only when the selected dataset actually changes, so manual edits survive
# unrelated reruns (switching model or projection). These Session-State writes
# must happen before the text_area keyed "dataset_text" is created below;
# Streamlit rejects writes to a widget's key after the widget is instantiated.
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = st.session_state.dataset_name
if st.session_state.dataset_name != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
    st.session_state.prev_dataset_name = st.session_state.dataset_name

with left:
    # A non-empty label avoids Streamlit's "`label` got an empty value"
    # accessibility warning and keeps the help tooltip attached to a label.
    st.text_area(
        label="words",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )

# One entry per non-blank line; surrounding whitespace is ignored.
words = [w.strip() for w in st.session_state.dataset_text.split("\n") if w.strip()]
with right:
    # PCA needs at least as many samples as components (3 in 3D mode), so
    # require three lines in both modes.
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Human-friendly chart title. Dataset keys stay lowercase in the picker and
    # some contain underscores ("ai_cs_concepts", "tech_companies"); plain
    # .title() rendered "Ai_Cs_Concepts", so swap underscores for spaces first.
    chart_title = st.session_state.dataset_name.replace("_", " ").title()

    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        fig = go.Figure(
            data=[
                go.Scatter(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    mode="markers+text",
                    text=words,
                    textposition="top center",
                    marker=dict(size=9),
                )
            ],
            layout=go.Layout(
                xaxis=dict(title="PC1"),
                # Lock the aspect ratio so on-screen distances are comparable
                # along both principal components.
                yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )
    else:
        coords = PCA(n_components=3).fit_transform(X)
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    z=coords[:, 2],
                    mode="markers+text",
                    text=words,
                    textposition="top center",
                    marker=dict(size=6),
                )
            ],
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                    yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                    zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
                ),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )

        # Simple Plotly rotation: one frame per 4° step of a camera circling
        # the z axis at a fixed height, played by the Rotate/Stop buttons.
        radius = 1.7
        z_eye = 1.0
        fig.frames = [
            go.Frame(
                layout=dict(
                    scene_camera=dict(
                        eye=dict(x=radius * np.cos(rad), y=radius * np.sin(rad), z=z_eye),
                        projection=dict(type="perspective"),
                    )
                )
            )
            for rad in np.deg2rad(np.arange(0, 360, 4))
        ]
        fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    x=0.02,
                    y=0.98,
                    buttons=[
                        dict(
                            label="▶ Rotate",
                            method="animate",
                            args=[
                                None,
                                dict(
                                    frame=dict(duration=40, redraw=True),
                                    transition=dict(duration=0),
                                    fromcurrent=True,
                                    mode="immediate",
                                ),
                            ],
                        ),
                        dict(
                            label="⏹ Stop",
                            method="animate",
                            # `[None]` (a list) is Plotly's idiom for "stop the
                            # running animation".
                            args=[
                                [None],
                                dict(
                                    frame=dict(duration=0, redraw=False),
                                    transition=dict(duration=0),
                                    mode="immediate",
                                ),
                            ],
                        ),
                    ],
                )
            ]
        )

    # Same centered title in both projection modes.
    fig.update_layout(
        title=dict(
            text=chart_title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            font=dict(size=20),
        )
    )
    st.plotly_chart(fig, use_container_width=True)