Spaces:
Runtime error
Runtime error
| import os | |
| # stop tensorflow from printing novels to stdout | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| import pickle | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import streamlit as st | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| from sklearn.cluster import DBSCAN | |
def read_stops(p: str):
    """
    Load the metro stops table from disk.

    :param p: The path to the .csv file of metro stops
    :returns: A Pandas DataFrame parsed from the .csv file
    """
    stops = pd.read_csv(p)
    return stops
def read_encodings(p: str) -> tf.Tensor:
    """
    Load the pickled Universal Sentence Encoder v4 encodings from disk.

    SECURITY: `pickle` can execute arbitrary code while loading, and this
    function does nothing to guard against that. Only point it at files you
    created yourself or otherwise trust.

    :param p: Path to the encodings
    :returns: A Tensor of the encodings with shape (number of sentences, 512)
    """
    with open(p, 'rb') as handle:
        # Whatever object was pickled is handed back unchanged.
        return pickle.load(handle)
def cluster_encodings(encodings: tf.Tensor, *, eps: float = 0.7,
                      min_samples: int = 100) -> np.ndarray:
    """
    Cluster the sentence encodings using DBSCAN.

    :param encodings: A Tensor of sentence encodings with shape
        (number of sentences, 512)
    :param eps: DBSCAN neighborhood radius. The default (0.7) is the value
        picked during the notebook EDA.
    :param min_samples: Minimum number of neighbors DBSCAN needs to form a
        core point. The default (100) is the value picked during the notebook EDA.
    :returns: a NumPy array of the cluster labels, one per encoding
    """
    # Defaults reproduce the hyperparameters chosen in the notebook; they are
    # exposed as keyword-only arguments so other values can be tried without
    # editing this function.
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(encodings)
    return clusterer.labels_
def cluster_lat_lon(df: pd.DataFrame, *, eps: float = 0.025,
                    min_samples: int = 100) -> np.ndarray:
    """
    Cluster the metro stops by their latitude and longitude using DBSCAN.

    :param df: A Pandas DataFrame of stops that has 'latitude' and 'longitude' columns
    :param eps: DBSCAN neighborhood radius. The default (0.025) is the value
        picked during the notebook EDA.
    :param min_samples: Minimum number of neighbors DBSCAN needs to form a
        core point. The default (100) is the value picked during the notebook EDA.
    :returns: a NumPy array of the cluster labels, one per row of `df`
    """
    # Only the coordinate columns are clustered; every other column is ignored.
    # NOTE(review): no metric is passed, so DBSCAN uses its default distance on
    # the raw coordinate values -- confirm that is what the EDA intended.
    coordinates = df[['latitude', 'longitude']]
    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(coordinates)
    return clusterer.labels_
def plot_example(df: pd.DataFrame, labels: np.ndarray):
    """
    Build the map of stops colored by their geographic cluster.

    :param df: A Pandas DataFrame of stops that has 'latitude', 'longitude'
        and 'display_name' columns
    :param labels: a NumPy array of the cluster labels
    :returns: the figure produced by `px.scatter_mapbox`
    """
    # The Mapbox token is read from the Streamlit secrets store.
    px.set_mapbox_access_token(st.secrets['mapbox_token'])
    return px.scatter_mapbox(
        df,
        lat='latitude',
        lon='longitude',
        hover_name='display_name',
        # Labels are passed as strings to go with the discrete color sequence below.
        color=labels.astype('str'),
        zoom=8,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray):
    """
    Build the map of stops colored by their name-based cluster, centered
    on Venice Blvd.

    :param df: A Pandas DataFrame of stops that has 'latitude', 'longitude'
        and 'display_name' columns
    :param labels: a NumPy array of the cluster labels
    :returns: the figure produced by `px.scatter_mapbox`
    """
    # The Mapbox token is read from the Streamlit secrets store.
    px.set_mapbox_access_token(st.secrets['mapbox_token'])
    # Map center used so the Venice Blvd stops are in view when the plot opens.
    map_center = dict(lat=34.008350, lon=-118.425362)
    return px.scatter_mapbox(
        df,
        lat='latitude',
        lon='longitude',
        # Labels are passed as strings to go with the discrete color sequence below.
        color=labels.astype('str'),
        hover_name='display_name',
        center=map_center,
        zoom=12,
        color_discrete_sequence=px.colors.qualitative.Dark24,
    )
def main(data_path: str, enc_path: str):
    """
    Build both clusterings of the metro stops and render them with Streamlit.

    Both figures are computed before anything is written to the page, so a
    failure while loading or clustering happens before any output is shown.

    :param data_path: Path to the .csv file of metro stops
    :param enc_path: Path to the pickled encodings of the stop names
    """
    stops = read_stops(data_path)

    # Figure 1: stops grouped by where they are (lat/lon).
    position_fig = plot_example(stops, cluster_lat_lon(stops))

    # Figure 2: stops grouped by what they are called (name encodings).
    name_labels = cluster_encodings(read_encodings(enc_path))
    name_fig = plot_venice_blvd(stops, name_labels)

    position_text = """First, I clustered the
stops by their geographic location.
The DBSCAN algorithm finds three clusters.
Points labeled `-1` aren't part of any cluster.
Clicking on `-1` in the legend will turn off those points."""

    name_text = """I encoded the names of all the stops using the Universal Sentence Encoder v4.
I then clustered those encodings so that I could group the stops based on their names
instead of their geographic position.
As I expected, stops on the same road end up close enough to each other that DBSCAN can cluster them together.
Sometimes, however, a stop has a name that means something to the encoder.
When that happens, the encoding ends up too far away from the rest of the stops on that road.
For example, the stops on Venice Blvd get clustered together,
but the stop "Venice / Lincoln" ends up somewhere else.
I assume it ends up somewhere else because the encoder recognizes "Lincoln"
and that meaning overpowers the "Venice" meaning enough that the encoding
is too far away from the rest of the "Venice" stops.
A few other examples on Venice Blvd are "Saint Andrews," "Harvard," and "Beethoven."
There are also a few that I don't ascribe much meaning to, such as "Girard" and "Robertson."
There's a lot more to dig into here but I'll leave it there for now.
My mind first jumps to adversarial prompts that use famous names to move the encoding
around in the encoding space.
"""

    # Page layout: heading, explanation, then the chart -- once per clustering.
    st.write('# Cluster the stops by their position')
    st.write(position_text)
    st.plotly_chart(position_fig, use_container_width=True)

    st.write('# Cluster the stops by their name')
    st.write(name_text)
    st.plotly_chart(name_fig, use_container_width=True)
if __name__ == '__main__':
    import argparse

    default_data_path = 'data/stops.csv'
    default_enc_path = 'data/encodings.pkl'

    p = argparse.ArgumentParser()
    # nargs='?' allows a flag to be given with no value. Without `const`,
    # argparse would then store None and the later file read would fail,
    # so a bare flag falls back to the same default as an omitted flag.
    p.add_argument('--data_path',
                   nargs='?',
                   default=default_data_path,
                   const=default_data_path,
                   help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
    p.add_argument('--enc_path',
                   nargs='?',
                   default=default_enc_path,
                   const=default_enc_path,
                   help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
    args = p.parse_args()
    # The argument names match main()'s parameter names, so the parsed
    # namespace can be unpacked straight into the call.
    main(**vars(args))