| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import marimo |
|
|
# Presumably the marimo version that generated this file (written by marimo itself).
__generated_with = "0.9.33"
# The notebook app object; every `@app.cell` function below registers one cell on it.
app = marimo.App(width="medium")
|
|
|
|
@app.cell
def __():
    # Import marimo once; returning `mo` lets every other cell receive it as a parameter.
    import marimo as mo

    return (mo,)
|
|
|
|
@app.cell(hide_code=True)
def __(mo):
    # Title, introduction, and a pointer to the blog post this app is based on.
    # The admonition body is indented so markdown attaches it to `!!! Info`.
    mo.md(
        r"""
        # Visualizing text embeddings using MotherDuck and marimo

        > Text embeddings have become a crucial tool in AI/ML applications, allowing us to convert text into numerical vectors that capture semantic meaning. These vectors are often used for semantic search, but in ~~this blog post~~ marimo app, we'll explore how to visualize and explore text embeddings interactively using MotherDuck and marimo.

        This app lets you visualize and explore text embeddings from Hacker News posts about **databases**. You can:

        - See how different posts cluster together based on semantic similarity
        - Adjust clustering parameters in real-time
        - Explore relationships between posts through an interactive visualization

        !!! Info
            **This marimo application is based on [this blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).** We recommend looking through the blog first.
        """
    )
    return
|
|
|
|
@app.cell(hide_code=True)
def __(mo):
    # Explains where the data comes from. The SQL shown is what the next cells
    # run, written as two complete statements so it can be copy-pasted as-is.
    mo.md(
        """
        ## Connecting to MotherDuck and Loading Sample Data

        This data has already been computed, but you can fork and edit this notebook to run it with your own data!

        ```sql
        ATTACH IF NOT EXISTS 'md:my_db';
        SELECT * FROM my_db.demo_with_embeddings;
        ```
        """
    )
    return
|
|
|
|
@app.cell
def __(mo):
    # Attach the MotherDuck database `my_db` so later SQL can query `my_db.*`.
    # `IF NOT EXISTS` makes re-running this cell a no-op once it is attached.
    # NOTE(review): presumably requires MotherDuck credentials in the environment — confirm.
    _df = mo.sql(
        """
        ATTACH IF NOT EXISTS 'md:my_db'
        """
    )
    # NOTE(review): `my_db` is not a Python variable here; it appears to be the
    # SQL-level name marimo records as this cell's definition (the embeddings
    # cell lists `my_db` as a parameter). Keep as generated by marimo.
    return (my_db,)
|
|
|
|
@app.cell
def __(mo):
    # Intentionally a no-op: every SQL line below is commented out. It records how
    # `my_db.demo_embedding_data` was built from the Hacker News parquet dataset
    # (deduplicated by url, titles containing 'database', score > 5). Uncomment it
    # to rebuild that table yourself.
    _df = mo.sql(
        """
        -- Commented out as we have already run the embeddings for showcasing purposes.

        -- CREATE OR REPLACE TABLE my_db.demo_embedding_data AS
        -- SELECT DISTINCT ON (url) * -- Remove duplicate URLs
        -- FROM 'hf://datasets/julien040/hacker-news-posts/story.parquet'
        -- WHERE contains(title, 'database') -- Filter for posts about databases
        -- AND score > 5 -- Only include popular posts
        -- LIMIT 50000;
        """
    )
    return
|
|
|
|
@app.cell
def __(demo_with_embeddings, mo, my_db):
    # Load the pre-computed embeddings into the `embeddings` dataframe.
    # The commented-out CREATE TABLE shows how `my_db.demo_with_embeddings` was
    # produced: `embedding(title)` over up to 1500 rows of demo_embedding_data.
    # The SELECT puts `title` and `text_embedding` first and drops id/comments,
    # which no other visible cell uses.
    # NOTE(review): `demo_with_embeddings` / `my_db` parameters appear to be
    # marimo's SQL-level dependencies, not Python values — keep as generated.
    embeddings = mo.sql(
        f"""
        -- Commented out as we have already run the embeddings for showcasing purposes.
        -- CREATE TABLE my_db.demo_with_embeddings AS
        -- SELECT *, embedding(title) as text_embedding
        -- FROM my_db.demo_embedding_data
        -- LIMIT 1500;

        SELECT title, text_embedding, * EXCLUDE(id, title, text_embedding, comments) FROM my_db.demo_with_embeddings;
        """
    )
    return (embeddings,)
|
|
|
|
@app.cell(hide_code=True)
def __(mo):
    # Prose-only cell: code hidden, like the notebook's other markdown cells.
    mo.md(
        """
        ## Making Sense of High-Dimensional Data

        Text embeddings typically have hundreds of dimensions (512 in our case), making them impossible to visualize directly. We'll use two techniques to make them interpretable:

        1. **Dimensionality Reduction**: Convert our 512D vectors into 2D points while preserving relationships between texts
        2. **Clustering**: Group similar texts together into clusters
        """
    )
    return
|
|
|
|
@app.cell(hide_code=True)
def __(cluster_points, mo, reduce_dimensions):
    def md_help(cls):
        """Return a markdown snippet showing a callable's signature and docstring.

        The snippet is a fenced ``python`` code block, so the docstring's
        indentation is shown verbatim instead of being parsed as markdown.
        ``cls`` may be any callable accepted by ``inspect.signature``.
        """
        import inspect
        import textwrap

        # No space between the name and the signature, matching real Python syntax.
        header = f"def {cls.__name__}{inspect.signature(cls)}:"
        # getdoc() strips the docstring's source indentation and returns None
        # when there is no docstring (the old code rendered the text "None").
        doc = inspect.getdoc(cls)
        body = (
            textwrap.indent(f'"""\n{doc}\n"""', "    ") if doc else "    ..."
        )
        return f"```python\n{header}\n{body}\n```"

    # Collapsible reference for the helper functions defined in another cell.
    mo.accordion(
        {
            "`reduce_dimensions`": md_help(reduce_dimensions),
            "`cluster_points`": md_help(cluster_points),
        }
    )
    return (md_help,)
|
|
|
|
@app.cell
def __(np):
    def reduce_dimensions(
        np_array,
        metric="cosine",
        n_neighbors=80,
        min_dist=0.1,
    ):
        """
        Reduce the dimensions of embeddings to a 2D space.

        Here we use the UMAP algorithm. UMAP preserves both local and
        global structure of the high-dimensional data.

        `metric`, `n_neighbors` and `min_dist` are forwarded to `umap.UMAP`;
        the defaults are the values this app was tuned with.
        """
        import umap

        reducer = umap.UMAP(
            n_components=2,
            metric=metric,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
        )
        return reducer.fit_transform(np_array)

    def cluster_points(
        np_array,
        min_cluster_size=4,
        max_cluster_size=50,
        min_samples=3,
        n_pca_components=50,
    ):
        """
        Cluster the embeddings.

        Here we use the HDBSCAN algorithm. We first reduce dimensionality to at
        most `n_pca_components` (50 by default) with PCA to speed up clustering,
        while still preserving most of the important information.

        Returns an array of labels: "outlier" for noise points, otherwise
        "cluster_<k>".
        """
        import hdbscan
        from sklearn.decomposition import PCA

        # PCA requires n_components <= min(n_samples, n_features); cap it so
        # small or low-dimensional inputs (e.g. your own data) do not raise.
        n_components = min(n_pca_components, *np.shape(np_array))
        reduced = PCA(n_components=n_components).fit_transform(np_array)

        hdb = hdbscan.HDBSCAN(
            min_samples=min_samples,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        ).fit(reduced)

        # np.char.add works on every NumPy version; `str + ndarray[str]`
        # only has a ufunc loop from NumPy 2.0 onward.
        named = np.char.add("cluster_", hdb.labels_.astype(str))
        # HDBSCAN marks noise points with label -1.
        return np.where(hdb.labels_ == -1, "outlier", named)

    return cluster_points, reduce_dimensions
|
|
|
|
@app.cell
def __(mo):
    # UI controls read by the processing cell: the slider supplies HDBSCAN's
    # min/max cluster size, and the dropdown picks UMAP's distance metric.
    cluster_size_slider = mo.ui.range_slider(
        # HDBSCAN rejects min_cluster_size below 2 ("Min cluster size must be
        # greater than one"), so start at 2 to keep every selectable range valid.
        start=2,
        stop=80,
        value=(4, 50),
        step=1,
        show_value=True,
        debounce=True,
        label="Cluster Size (min, max)",
    )
    metric_dropdown = mo.ui.dropdown(
        ["cosine", "euclidean", "manhattan"],
        value="cosine",
        label="Distance Metric",
    )
    return cluster_size_slider, metric_dropdown
|
|
|
|
@app.cell(hide_code=True)
def __(mo):
    # Prose-only cell: code hidden, like the notebook's other markdown cells.
    mo.md(
        r"""
        ## Processing the Data

        Now we'll transform our high-dimensional embeddings into something we can visualize, using `reduce_dimensions` and `cluster_points`. More details on this step [in the blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).
        """
    )
    return
|
|
|
|
@app.cell
def __(
    cluster_points,
    cluster_size_slider,
    embeddings,
    metric_dropdown,
    mo,
    reduce_dimensions,
):
    # Cluster in the original embedding space, then project to 2D for plotting.
    # This single cell reads both UI controls, so both steps rerun whenever
    # either control changes.
    with mo.status.spinner("Clustering points...") as _spinner:
        # NOTE(review): numba is not used directly in this cell; presumably it is
        # imported so the JIT backend UMAP relies on is loaded up front — confirm.
        import numba

        embeddings_array = embeddings["text_embedding"].to_numpy()
        _smallest, _largest = cluster_size_slider.value
        hdb_labels = cluster_points(
            embeddings_array,
            min_cluster_size=_smallest,
            max_cluster_size=_largest,
        )

        _spinner.update("Reducing dimensionality...")
        embeddings_2d = reduce_dimensions(
            embeddings_array,
            metric=metric_dropdown.value,
        )
    mo.show_code()
    return embeddings_2d, embeddings_array, hdb_labels, numba
|
|
|
|
@app.cell
def __(cluster_size_slider, metric_dropdown, mo):
    # Lay the two processing controls out side by side.
    _controls = [cluster_size_slider, metric_dropdown]
    mo.hstack(_controls)
    return
|
|
|
|
@app.cell
def __(embeddings, embeddings_2d, hdb_labels, pl):
    # Build the plotting table in one lazy pipeline:
    # attach 2D coordinates and cluster labels, keep the first row per url,
    # drop the raw vectors, and discard HDBSCAN outliers.
    data = (
        embeddings.lazy()
        .with_columns(
            text_embedding_2d_1=embeddings_2d[:, 0],
            text_embedding_2d_2=embeddings_2d[:, 1],
            cluster=hdb_labels,
        )
        .unique(subset=["url"], maintain_order=True)
        .drop(["text_embedding"])
        .filter(pl.col("cluster") != "outlier")
        .collect()
    )
    return (data,)
|
|
|
|
@app.cell
def __(data):
    # Preview just the columns that matter for the chart.
    _preview_columns = [
        "title",
        "cluster",
        "text_embedding_2d_1",
        "text_embedding_2d_2",
        "score",
    ]
    data.select(_preview_columns)
    return
|
|
|
|
@app.cell
def __(alt, data, mo):
    # Scatter of the 2D projection, colored by cluster. zero=False lets each
    # axis fit the projected coordinates instead of forcing 0 into view.
    _x_axis = alt.X("text_embedding_2d_1", scale=alt.Scale(zero=False))
    _y_axis = alt.Y("text_embedding_2d_2", scale=alt.Scale(zero=False))
    _scatter = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x=_x_axis,
            y=_y_axis,
            color="cluster",
            tooltip=["title", "score", "cluster"],
        )
    )
    # Wrap in a marimo UI element so point selections are reactive.
    chart = mo.ui.altair_chart(_scatter)
    mo.show_code()
    return (chart,)
|
|
|
|
@app.cell(hide_code=True)
def __(mo):
    # Section header introducing the interactive scatter plot rendered below.
    _section = mo.md(
        r"""
        ## Creating an Interactive Visualization

        We will plot the 2D representation of the text embeddings, colored by the clusters identified by HDBSCAN. You can select points on the chart to explore the text embeddings further. 👇
        """
    )
    _section
    return
|
|
|
|
@app.cell
def __(chart):
    # Render the interactive chart; selections are exposed via `chart.value`.
    chart
    return
|
|
|
|
@app.cell
def __(chart):
    # Show the current chart selection (presumably a dataframe of selected points);
    # marimo reruns this cell whenever the selection changes.
    chart.value
    return
|
|
|
|
@app.cell
def __(mo):
    # Empty 400px-tall block: purely visual vertical spacing in the rendered app.
    mo.Html("<div style='height: 400px;'></div>")
    return
|
|
|
|
@app.cell
def __():
    # Shared third-party imports, exported for the other cells.
    # NOTE(review): duckdb and pyarrow are not referenced by other cells;
    # presumably they are imported as backends for mo.sql / Arrow conversion.
    import altair as alt
    import duckdb
    import numpy as np
    import polars as pl
    import pyarrow

    return alt, duckdb, np, pl, pyarrow
|
|
|
|
if __name__ == "__main__":
    # Allow running the notebook directly as a script.
    app.run()
|
|