Spaces:
Paused
Paused
| from logging import getLogger | |
| from pathlib import Path | |
| import pandas as pd | |
| import plotly.express as px | |
| import streamlit as st | |
| from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder | |
| from streamlit_plotly_events import plotly_events | |
| from utilities import initialization | |
# Project-specific setup from the local ``utilities`` module, executed at module
# top level before anything else in this script.
# NOTE(review): presumably this fills st.session_state with ``data``, ``topics``
# and ``topic_str_to_word`` (read by main() and filter_df()), replacing the
# commented-out initialize_state() below — confirm in utilities.initialization.
initialization()
| # @st.cache(show_spinner=False) | |
| # def initialize_state(): | |
| # with st.spinner("Loading app..."): | |
| # if 'model' not in st.session_state: | |
| # model = Top2Vec.load('models/model.pkl') | |
| # model._check_model_status() | |
| # model.hierarchical_topic_reduction(num_topics=20) | |
| # | |
| # st.session_state.model = model | |
| # st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav') | |
| # logger.info("loading data...") | |
| # | |
| # if 'data' not in st.session_state: | |
| # logger.info("loading data...") | |
| # data = pd.read_csv(proj_dir / 'data' / 'data.csv') | |
| # data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}') | |
| # st.session_state.data = data | |
| # st.session_state.selected_data = data | |
| # st.session_state.all_topics = list(data.topic_id.unique()) | |
| # | |
| # if 'topics' not in st.session_state: | |
| # logger.info("loading topics...") | |
| # topics = pd.read_csv(proj_dir / 'data' / 'topics.csv') | |
| # topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}') | |
| # st.session_state.topics = topics | |
def reset():
    """Clear the lasso/box selection and show the full dataset again.

    Used as the ``on_click`` callback of the "Reset" button in ``main``.
    Restores ``st.session_state.selected_data`` to the unfiltered
    ``st.session_state.data`` and empties ``st.session_state.selected_points``.

    (This file previously defined ``reset`` twice; the later, log-less copy
    silently shadowed this one. A single definition is kept.)
    """
    logger.info("Resetting...")
    st.session_state.selected_data = st.session_state.data
    st.session_state.selected_points = []


def filter_df():
    """Restrict ``st.session_state.selected_data`` to the selected points.

    Each entry of ``st.session_state.selected_points`` is a point dict from
    ``plotly_events``; only its ``x`` and ``y`` values are used to match rows
    of ``st.session_state.data``. When nothing is selected, the current
    ``selected_data`` is left untouched.
    """
    if not st.session_state.selected_points:
        logger.info("Lame")
        return

    # Match selected points back to documents by their 2-D coordinates.
    # Duplicate coordinate pairs are dropped first: if k documents share one
    # (x, y) and all are selected, the inner merge would otherwise return k*k
    # rows instead of k.
    points_df = (
        pd.DataFrame(st.session_state.selected_points)
        .loc[:, ['x', 'y']]
        .drop_duplicates()
    )
    st.session_state.selected_data = st.session_state.data.merge(points_df, on=['x', 'y'])
    logger.info("Updates selected_data: %s", len(st.session_state.selected_data))
_EMPTY_SELECTION_MSG = 'Select points in the graph with the `lasso` or `box` select tools to populate this table.'


def _topic_counts(selected_data: pd.DataFrame, topics: pd.DataFrame) -> pd.DataFrame:
    """Return per-topic document counts for *selected_data* joined with *topics*.

    Columns are ``topic_id``, ``topic_count``, then the remaining columns of
    *topics*. Rows keep ``value_counts`` order (descending count), because an
    inner merge preserves the order of the left keys. Topics missing from
    *topics* are dropped by the inner join.
    """
    # Build the count table with explicit names. ``value_counts().to_frame()``
    # names its column ``topic_id`` on pandas < 2.0 but ``count`` on
    # pandas >= 2.0, which broke the old ``topic_id_x``/``topic_id_y`` handling.
    counts = (
        selected_data['topic_id']
        .value_counts()
        .rename_axis('topic_id')
        .reset_index(name='topic_count')
    )
    merged = counts.merge(topics, on='topic_id')
    cols = ['topic_id'] + [col for col in merged.columns if col != 'topic_id']
    return merged[cols]


def _render_docs_table():
    """Show the selected documents, plus each one's topic word, in a paginated grid."""
    ordered_cols = ['id', 'topic_id', 'topic_word', 'documents']
    selected = st.session_state.selected_data[['id', 'topic_id', 'documents']]
    # ``assign`` returns a new frame. This avoids the SettingWithCopyWarning
    # raised by adding a column to a column slice of ``selected_data``.
    docs = selected.assign(
        topic_word=selected['topic_id'].replace(st.session_state.topic_str_to_word)
    )[ordered_cols]
    builder = GridOptionsBuilder.from_dataframe(docs)
    builder.configure_pagination()
    AgGrid(docs, theme='streamlit', gridOptions=builder.build(),
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)


def _render_topics_table():
    """Show how many selected documents fall into each topic, in a paginated grid."""
    cols = ['topic_id', 'topic_count', 'topic_0']
    topic_counts = _topic_counts(st.session_state.selected_data,
                                 st.session_state.topics).loc[:, cols]
    builder = GridOptionsBuilder.from_dataframe(topic_counts)
    builder.configure_pagination()
    builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
    AgGrid(topic_counts, theme='streamlit', gridOptions=builder.build(),
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)


def main():
    """Render the topic-modeling explorer page.

    Draws a scatter plot of ``st.session_state.data`` (one point per row,
    coloured by topic word) and stores the points picked with the lasso/box
    tools in ``st.session_state.selected_points``. It then narrows
    ``selected_data`` via ``filter_df`` and shows the selection in a "Docs"
    tab and a "Topics" tab.
    """
    st.write("""
    # Topic Modeling
    This shows a 2d representation of documents embeded in a semantic space. Each dot is a document
    and the dots close represent documents that are close in meaning.
    Zoom in and explore a topic of your choice. You can see the documents you select with the `lasso` or `box`
    tool below in the corresponding tabs."""
             )
    st.button("Reset", help="Will Reset the selected points and the selected topics", on_click=reset)

    # Sort so the legend is ordered by topic: https://bioinformatics.stackexchange.com/a/18847
    data_to_model = st.session_state.data.sort_values(by='topic_id', ascending=True)
    # Assign the result instead of calling replace(..., inplace=True) on a column.
    # That chained in-place pattern warns under pandas Copy-on-Write and no
    # longer updates the frame in pandas 3.
    data_to_model['topic_id'] = data_to_model['topic_id'].replace(
        st.session_state.topic_str_to_word)
    fig = px.scatter(data_to_model, x='x', y='y', color='topic_id', template='plotly_dark',
                     hover_data=['id', 'topic_id', 'x', 'y'])
    st.session_state.selected_points = plotly_events(fig, select_event=True, click_event=False)
    # Derive selected_data once per run; both tabs read the same result.
    filter_df()

    docs_tab, topics_tab = st.tabs(["Docs", "Topics"])
    with docs_tab:
        if st.session_state.selected_points:
            _render_docs_table()
        else:
            st.markdown(_EMPTY_SELECTION_MSG)
    with topics_tab:
        if st.session_state.selected_points:
            _render_topics_table()
        else:
            st.markdown(_EMPTY_SELECTION_MSG)
# Bind the logger at module level. reset() and filter_df() log through it, so
# it must exist even when this module is imported instead of run as __main__
# (previously it was only defined inside the guard below).
logger = getLogger(__name__)

if __name__ == "__main__":
    # Directory two levels above the one containing this file.
    # NOTE(review): proj_dir is only referenced by the commented-out
    # initialize_state() above — presumably the project root; confirm.
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    md_title = "# Document Explorer π"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # initialize_state()
    main()