Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import streamlit as st | |
| from Functionalities import NLP_Helper | |
| from Functionalities.TopicClustering import TopicClustering | |
| from streamlit_extras.dataframe_explorer import dataframe_explorer | |
class TopicClusterView:
    """
    Streamlit view for clustering the texts of an uploaded CSV into topics.

    The view collects the input parameters, runs ``TopicClustering`` on them and
    keeps the resulting object in ``st.session_state.topic_cluster`` so that it
    survives Streamlit's script reruns. The remaining methods render that stored
    result (table, renaming form, downloads, plots).
    """

    def __init__(self):
        """
        Configure the page, initialise the per-run widget values and make sure the
        ``topic_cluster`` session-state slot exists.
        """
        # Streamlit requires set_page_config to be the first Streamlit command of a run.
        st.set_page_config(page_title='Topic Clustering', layout="wide")

        self.n_neighbors = 10  # initial position of the neighbourhood-size slider
        self.topic_cluster = None  # TopicClustering built by run_clustering
        self.representation_model = None  # selected entry of NLP_Helper.BERTOPIC_REPRESENTATIONS
        self.sentence_model = None  # selected entry of NLP_Helper.TRANSFORMERS
        self.text_col = None  # name of the column holding the texts
        self.text_df = None  # DataFrame read from the uploaded file
        self.text_file = None  # file object returned by st.file_uploader

        # Keep an existing result across reruns; only create the slot when missing.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None

        st.header("Topic Clustering")

    @staticmethod
    def _is_fitted(topic_cluster) -> bool:
        """
        Tell whether a clustering result is available for display.

        :param topic_cluster: the stored clustering object, or ``None``
        :return: ``True`` when the object exists and its ``topic_model`` is not ``None``
        """
        return topic_cluster is not None and topic_cluster.topic_model is not None

    @staticmethod
    def _rerun() -> None:
        """
        Restart the script run, using ``st.rerun`` when the installed Streamlit has
        it and the older ``st.experimental_rerun`` otherwise.
        """
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

    def input_params(self) -> None:
        """
        Takes csv file input, name of text col, select option for sentence model and
        representation model, and renders the button that starts the clustering.

        Nothing beyond the uploader is rendered until a readable CSV is uploaded,
        so the clustering can never be started without data.

        :return:
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        try:
            self.text_df = pd.read_csv(self.text_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            # Report unreadable uploads on the page instead of crashing the app.
            self.text_df = None
            st.error(f"Could not read **{self.text_file.name}** as a CSV file: {err}")
            return

        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns,
        )
        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO),
        )
        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )
        # The callback is bound to this instance, so it sees the values chosen above.
        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """
        Build a ``TopicClustering`` from the collected parameters, run it and store
        it in ``st.session_state.topic_cluster``.

        Does nothing except show a warning when no data frame or text column has
        been selected yet.

        :return:
        """
        if self.text_df is None or self.text_col is None:
            st.warning("Upload a CSV file and choose a text column before clustering.")
            return

        self.topic_cluster = TopicClustering(
            keyword_df=self.text_df,
            text_col=self.text_col,
            representation_model=self.representation_model,
            sentence_model=self.sentence_model,
        )
        # NOTE(review): presumably fits the topic model and fills keyword_df - confirm in TopicClustering.
        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self) -> None:
        """
        Show the clustered data frame, offer renaming of the topics and provide the
        CSV downloads. Renders nothing while no fitted clustering is stored.

        :return:
        """
        cluster = st.session_state.topic_cluster
        if not self._is_fitted(cluster):
            return

        filtered_df = dataframe_explorer(cluster.keyword_df)
        st.dataframe(filtered_df)

        with st.expander("Rename Topics"):
            for topic_name in cluster.topic_names:
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                # The current name is the default, so untouched topics map to themselves.
                cluster.topic_name_mapping[topic_name] = new_topic_col.text_input("New topic name", topic_name)
            if st.button("Update Topic Names"):
                cluster.update_topic_names()
                # Rerun so the table and the form above pick up the new names.
                self._rerun()

        st.download_button(
            "Press to Download as CSV",
            cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv',
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export once and reuse it for both the preview and the download.
            google_ads_df = cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv',
            )

    def visualize_clusters(self) -> None:
        """
        Render the neighbourhood-size slider and, on button press, the document
        plot of the stored clustering. Renders nothing while no fitted clustering
        is stored.

        :return:
        """
        cluster = st.session_state.topic_cluster
        if not self._is_fitted(cluster):
            return

        # Start the slider at the default set in __init__ rather than at min_value.
        self.n_neighbors = st.slider(
            label='Size of the local neighborhood',
            min_value=2,
            max_value=100,
            value=self.n_neighbors,
            step=1,
        )
        if st.button("Visualize Topic Clusters"):
            fig = cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self) -> None:
        """
        Render the topic-distribution plot of the stored clustering. Renders
        nothing while no fitted clustering is stored.

        :return:
        """
        cluster = st.session_state.topic_cluster
        if not self._is_fitted(cluster):
            return

        fig = cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)
if __name__ == '__main__':
    # Build the page: the constructor configures the page and draws the header.
    topic_cluster_view = TopicClusterView()

    # Only two tabs are active; a third 'Topic Distribution' tab backed by
    # visualize_topic_distribution() is currently disabled.
    clustering_tab, visualization_tab = st.tabs(['Clustering Process', 'Cluster Visualization'])

    # Inputs first, then the results table and download controls.
    with clustering_tab:
        topic_cluster_view.input_params()
        topic_cluster_view.show_and_download_df()

    with visualization_tab:
        topic_cluster_view.visualize_clusters()