Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| from src.subpages.page import Context, Page | |
class HiddenStatesVisualizer:
    """Project per-token hidden states to 2-D and plot them colored by label and by prediction.

    The projection is computed on every row of ``context.df_tokens_merged``; only the
    first ``n_tokens`` rows (chosen in the settings column) are drawn.
    """

    # Output dimensionality of every projection (the plots use an x and a y column).
    _N_COMPONENTS = 2
    # Extra random directions sampled by the randomized SVD to improve accuracy.
    _N_OVERSAMPLES = 10
    # Power iterations used for PCA, which has no user-facing iteration slider.
    _PCA_N_ITER = 7
    # The hover text splits each sentence into this many chunks of this many characters.
    _N_SENT_CHUNKS = 5
    _SENT_CHUNK_CHARS = 50

    def __init__(self, context: "Context"):
        self.context = context
        # Work on a copy: plotting adds columns (x, y, sent0..4, disagreements) to it.
        self.df = context.df_tokens_merged.copy()

    @classmethod
    def _project_onto_top_components(cls, X, n_iter: int, random_state) -> np.ndarray:
        """Project the rows of ``X`` onto its top two right singular vectors.

        Uses a randomized SVD: a Gaussian test matrix seeded by ``random_state`` is
        refined with ``n_iter`` QR-normalized power iterations. Output columns beyond
        the rank of ``X`` are zero-filled, so the result always has shape (n_rows, 2).
        """
        X = np.asarray(X)
        if not np.issubdtype(X.dtype, np.floating):
            X = X.astype(np.float64)
        n_out = cls._N_COMPONENTS
        if X.size == 0:
            return np.zeros((len(X), n_out))

        n_rows, n_cols = X.shape
        k = min(n_out + cls._N_OVERSAMPLES, n_rows, n_cols)
        rng = np.random.default_rng(random_state)
        basis = X @ rng.standard_normal((n_cols, k))
        for _ in range(n_iter):
            # Re-orthonormalize before each power step to avoid losing small directions.
            basis, _r = np.linalg.qr(basis)
            basis = X @ (X.T @ basis)
        basis, _r = np.linalg.qr(basis)
        # The SVD of the small (k x n_cols) matrix yields the approximate right singular vectors.
        _u, _s, vt = np.linalg.svd(basis.T @ X, full_matrices=False)
        projected = X @ vt[:n_out].T
        if projected.shape[1] < n_out:
            padding = np.zeros((n_rows, n_out - projected.shape[1]))
            projected = np.hstack([projected, padding])
        return projected

    def _reduce_dim_svd(self, X, n_iter: int, random_state=42) -> np.ndarray:
        """Reduce ``X`` (one row per token) to 2-D with truncated SVD; no centering is applied.

        Args:
            X: matrix with one row per token.
            n_iter: number of power iterations of the randomized solver.
            random_state: seed for the random test matrix.

        Returns:
            Array of shape (n_rows, 2).
        """
        return self._project_onto_top_components(X, n_iter=n_iter, random_state=random_state)

    def _reduce_dim_pca(self, X, random_state=42) -> np.ndarray:
        """Reduce ``X`` (one row per token) to 2-D with PCA (SVD of the column-centered data).

        Args:
            X: matrix with one row per token.
            random_state: seed for the randomized solver.

        Returns:
            Array of shape (n_rows, 2).
        """
        X = np.asarray(X)
        if X.size:
            X = X - X.mean(axis=0)
        return self._project_onto_top_components(
            X, n_iter=self._PCA_N_ITER, random_state=random_state
        )

    def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean") -> np.ndarray:
        """UMAP projection; not available here.

        Raises:
            NotImplementedError: always. UMAP needs the optional ``umap-learn`` package,
                which this module does not import.
        """
        raise NotImplementedError(
            "UMAP projection is not available: it needs the optional `umap-learn` package, "
            "which this page does not import. Choose SVD or PCA instead."
        )

    def _render_settings(self):
        """Draw the settings widgets.

        Returns:
            ``(n_tokens, dim_algo, reducer_kwargs)``, where ``reducer_kwargs`` holds the
            algorithm-specific options to pass to the chosen reducer.
        """
        st.subheader("Settings")
        total_rows = len(self.df)
        # ``n_tokens`` slices rows, so the slider is bounded by the row count. Small frames
        # (where a 100..max slider would be empty) are plotted in full.
        if total_rows > 100:
            n_tokens = st.slider(
                "#tokens",
                key="n_tokens",
                min_value=100,
                max_value=total_rows,
                step=100,
            )
        else:
            n_tokens = total_rows

        dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
        reducer_kwargs = {}
        if dim_algo == "SVD":
            reducer_kwargs["n_iter"] = st.slider(
                "#iterations",
                key="svd_n_iter",
                min_value=1,
                max_value=10,
                step=1,
            )
        elif dim_algo == "UMAP":
            reducer_kwargs["n_neighbors"] = st.slider(
                "#neighbors",
                key="umap_n_neighbors",
                min_value=2,
                max_value=100,
                step=1,
            )
            reducer_kwargs["min_dist"] = st.number_input(
                "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
            )
            reducer_kwargs["metric"] = st.selectbox(
                "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
            )
        return n_tokens, dim_algo, reducer_kwargs

    def _reduce(self, dim_algo: str, **reducer_kwargs) -> np.ndarray:
        """Run the reducer named ``dim_algo`` on all hidden states and return (n_rows, 2)."""
        reducers = {
            "SVD": self._reduce_dim_svd,
            "PCA": self._reduce_dim_pca,
            "UMAP": self._reduce_dim_umap,
        }
        if dim_algo not in reducers:
            raise ValueError(f"Unknown dimensionality reduction algorithm: {dim_algo!r}")
        X = np.array(self.df["hidden_states"].tolist())
        return reducers[dim_algo](X, **reducer_kwargs)

    def _add_plot_columns(self, projected: np.ndarray) -> None:
        """Add the x/y coordinates, hover-text chunks and disagreement flag to ``self.df``."""
        self.df["x"] = projected[:, 0]
        self.df["y"] = projected[:, 1]
        # One space-joined string of tokens per ``ids`` group.
        sents = self.df.groupby("ids")["tokens"].apply(" ".join)
        width = self._SENT_CHUNK_CHARS
        for i in range(self._N_SENT_CHUNKS):
            start = i * width
            # Whitespace-normalized character slice, computed once per sentence and then
            # broadcast to that sentence's token rows.
            chunk = sents.map(lambda s, a=start, b=start + width: " ".join(s[a:b].split()))
            self.df[f"sent{i}"] = self.df["ids"].map(chunk)
        self.df["disagreements"] = self.df["labels"] != self.df["preds"]

    def _render_plots(self, subset) -> None:
        """Draw the label-colored and prediction-colored scatter plots for ``subset``."""
        flagged = subset[subset["disagreements"]]
        # Hollow markers overlaid on the colored points to outline label/prediction disagreements.
        disagreements_trace = go.Scatter(
            x=flagged["x"],
            y=flagged["y"],
            mode="markers",
            marker=dict(
                size=6,
                color="rgba(0,0,0,0)",
                line=dict(width=1, color="black"),
            ),
            hoverinfo="skip",
            name="disagreements",
        )
        sent_columns = [f"sent{i}" for i in range(self._N_SENT_CHUNKS)]
        st.subheader("Projection Results")
        for color_by, other, title in (
            ("labels", "preds", "Colored by label"),
            ("preds", "labels", "Colored by prediction"),
        ):
            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color=color_by,
                hover_data=["ids", other, *sent_columns],
                hover_name="tokens",
                title=title,
            )
            fig.add_trace(disagreements_trace)
            st.plotly_chart(fig)

    def visualize_hidden_states(self):
        """Render the Embeddings page: a settings column and two projection scatter plots."""
        st.title("Embeddings")
        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
            )
        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])

        with col1:
            n_tokens, dim_algo, reducer_kwargs = self._render_settings()

        with col2:
            try:
                projected = self._reduce(dim_algo, **reducer_kwargs)
            except NotImplementedError as err:
                # Surface unavailable algorithms (UMAP) in the UI instead of crashing the page.
                st.warning(str(err))
                return
            self._add_plot_columns(projected)
            self._render_plots(self.df.iloc[:n_tokens])
class HiddenStatesPage(Page):
    """Page entry that shows 2-D projections of the token hidden states."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Build a visualizer for ``context`` and draw the hidden-states page."""
        HiddenStatesVisualizer(context).visualize_hidden_states()