Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from pyvis.network import Network | |
| import plotly.express as px | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sentence_transformers import SentenceTransformer | |
| from bertopic import BERTopic | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| import pandas as pd | |
| import numpy as np | |
| import networkx as nx | |
| import textwrap | |
| import logging | |
| # from streamlit.ReportThread import REPORT_CONTEXT_ATTR_NAME | |
| from threading import current_thread | |
| from contextlib import contextmanager | |
| from io import StringIO | |
| import sys | |
| import time | |
| logger = logging.getLogger('main') | |
def reset_default_topic_sliders(min_topic_size, n_gram_range):
    """Write the topic-model settings into Streamlit session state.

    Stores ``min_topic_size`` under the ``'min_topic_size'`` key and
    ``n_gram_range`` under the ``'n_gram_range'`` key.
    """
    defaults = {
        'min_topic_size': min_topic_size,
        'n_gram_range': n_gram_range,
    }
    for key, value in defaults.items():
        st.session_state[key] = value
def reset_default_threshold_slider(threshold):
    """Store ``threshold`` under the ``'threshold'`` session-state key."""
    session = st.session_state
    session['threshold'] = threshold
def load_data(uploaded_file):
    """Parse ``uploaded_file`` as CSV and return the resulting DataFrame."""
    return pd.read_csv(uploaded_file)
def embedding_gen(data):
    """Encode the ``'Text'`` column of ``data`` into sentence embeddings.

    The SentenceTransformer model is loaded from the local directory
    ``./sentence-transformers_allenai-specter`` on every call.
    """
    logger.info('Calculating Embeddings')
    encoder = SentenceTransformer('./sentence-transformers_allenai-specter')
    documents = data['Text']
    return encoder.encode(documents)
def load_bertopic_model(min_topic_size, n_gram_range):
    """Build an unfitted, verbose BERTopic model.

    The model uses a CountVectorizer with English stop words and the given
    ``n_gram_range``; ``min_topic_size`` is forwarded to BERTopic unchanged.
    """
    logger.info('Loading BERTopic model')
    vectorizer = CountVectorizer(
        stop_words='english',
        ngram_range=n_gram_range,
    )
    model = BERTopic(
        vectorizer_model=vectorizer,
        min_topic_size=min_topic_size,
        verbose=True,
    )
    return model
def topic_modeling(data, min_topic_size, n_gram_range):
    """Fit a BERTopic model on ``data['Text']`` and attach topics to each row.

    Args:
        data: DataFrame with a ``'Text'`` column.
        min_topic_size: Forwarded to ``load_bertopic_model``.
        n_gram_range: Forwarded to ``load_bertopic_model``.

    Returns:
        A ``(topic_data, topic_model, topics)`` tuple where

        * ``topic_data`` is a copy of ``data`` with ``Topic``, ``Probs``,
          ``Topic_Count`` and ``Topic_Name`` columns added,
        * ``topic_model`` is the fitted BERTopic model,
        * ``topics`` maps the topic id of each of the 10 largest topics to
          ``{'Topic_Count': ..., 'Topic_Name': ...}``.
    """
    logger.info('Calculating Topic Model')
    topic_model = load_bertopic_model(min_topic_size, n_gram_range)

    # Train the topic model on precomputed embeddings
    topic_data = data.copy()
    topic_data["Topic"], topic_data["Probs"] = topic_model.fit_transform(
        data['Text'], embeddings=embedding_gen(data))

    # Select the three columns by name instead of overwriting ``.columns``
    # with a 3-item list: get_topic_info() may return additional columns, and
    # a positional overwrite then fails with a length-mismatch ValueError.
    topic_info = topic_model.get_topic_info()
    topic_df = (
        topic_info[['Topic', 'Count', 'Name']]
        .rename(columns={'Count': 'Topic_Count', 'Name': 'Topic_Name'})
        .sort_values(by='Topic_Count', ascending=False)
    )

    # Merge topic results back onto the documents
    topic_data = topic_data.merge(topic_df, on='Topic', how='left')

    # Optimization: only keep the 10 largest topics
    topics = topic_df.head(10).set_index('Topic').to_dict(orient='index')

    logger.info('Topic Modeling Complete')
    return topic_data, topic_model, topics
def cosine_sim(data):
    """Return the strictly upper-triangular cosine-similarity matrix.

    Embeddings come from ``embedding_gen(data)``. The diagonal and everything
    below it are zeroed, so each pair appears once.
    """
    logger.info('Cosine similarity')
    embeddings = embedding_gen(data)
    similarities = cosine_similarity(embeddings)
    # k=1 keeps only entries above the main diagonal
    return np.triu(similarities, k=1)
def calc_max_connections(num_papers, ratio):
    """Return ``n * (n - 1) / 2`` for ``n = ratio * num_papers``.

    This is the number of distinct pairs among ``n`` items.
    """
    scaled = ratio * num_papers
    pair_product = scaled * (scaled - 1)
    return pair_product / 2
def calc_neighbors(cosine_sim_matrix, threshold):
    """List the index pairs whose similarity is at least ``threshold``.

    Returns:
        ``(pairs, count)`` where ``pairs`` is a list of ``[row, col]`` index
        lists and ``count`` is ``len(pairs)``.
    """
    above_threshold = cosine_sim_matrix >= threshold
    pairs = np.argwhere(above_threshold).tolist()
    count = len(pairs)
    return pairs, count
def calc_optimal_threshold(cosine_sim_matrix, max_connections):
    """Find the similarity threshold that keeps the graph under a size cap.

    Thresholds are swept from 1.0 down to 0.05 in steps of 0.05. The sweep
    stops at the first threshold for which the number of matrix entries
    ``>= threshold`` exceeds ``max_connections``.

    Args:
        cosine_sim_matrix: NumPy array of pairwise similarities.
        max_connections: Maximum number of entries allowed above threshold.

    Returns:
        ``(previous, stop)`` as Python floats rounded to 2 decimals, where
        ``stop`` is the threshold at which the sweep ended and ``previous``
        is the next stricter threshold in the sweep.

        * If the cap is never exceeded, the sweep ends on its last value and
          ``(0.1, 0.05)`` is returned.
        * If the strictest threshold already exceeds the cap, there is no
          stricter value, so that threshold is returned for both elements.
    """
    logger.info('Calculating optimal threshold')
    thresh_sweep = np.arange(0.05, 1.05, 0.05)[::-1]

    # Default: the sweep ran to completion without exceeding the cap.
    stop = len(thresh_sweep) - 1
    for idx, threshold in enumerate(thresh_sweep):
        # Counting the mask directly avoids building a list of index pairs
        # on every step of the sweep.
        num_neighbors = np.count_nonzero(cosine_sim_matrix >= threshold)
        if num_neighbors > max_connections:
            stop = idx
            break

    # Clamp at 0: an index of -1 would wrap around to the loosest threshold.
    previous = max(stop - 1, 0)
    return (
        float(np.round(thresh_sweep[previous], 2)),
        float(np.round(thresh_sweep[stop], 2)),
    )
def nx_hash_func(nx_net):
    """Return ``(nodes, edges)`` of a NetworkX graph as two plain lists.

    Intended as a hash function for NetworkX graphs.
    """
    node_list = list(nx_net.nodes())
    edge_list = list(nx_net.edges())
    return node_list, edge_list
def pyvis_hash_func(pyvis_net):
    """Return the ``nodes`` and ``edges`` attributes of a pyvis network.

    Intended as a hash function for pyvis graphs.
    """
    nodes = pyvis_net.nodes
    edges = pyvis_net.edges
    return nodes, edges
def network_plot(topic_data, topics, neighbors):
    """Build the paper-similarity graph, grouped by topic.

    Args:
        topic_data: DataFrame with ``Topic`` and ``Text`` columns; its index
            values are used as node ids and labels.
        topics: Mapping of topic id to a dict that has a ``'Topic_Name'`` key.
            One legend node is created per entry.
        neighbors: Iterable of node-id pairs to add as edges.

    Returns:
        ``(nx_net, pyvis_net)``: the NetworkX graph and the pyvis network
        populated from it.
    """
    logger.info('Calculating Network Plot')
    graph = nx.Graph()
    canvas = Network(height='750px', width='100%', bgcolor='#222222')

    # One node per paper, grouped by its topic
    paper_nodes = []
    for row in topic_data.itertuples():
        attributes = {
            'group': row.Topic,
            'label': row.Index,
            'title': text_processing(row.Text),
            'size': 20,
            'font': {'size': 20, 'color': 'white'},
        }
        paper_nodes.append((row.Index, attributes))
    graph.add_nodes_from(paper_nodes)
    assert graph.number_of_nodes() == len(topic_data)

    # One edge per neighbor pair
    graph.add_edges_from(neighbors)
    assert graph.number_of_edges() == len(neighbors)

    # Optimization: drop papers that have no connections
    isolated = list(nx.isolates(graph))
    graph.remove_nodes_from(isolated)

    # Legend: one fixed-position box per topic, stacked top to bottom.
    # Legend ids start right after the paper ids.
    legend_x = -2000
    legend_top = -500
    legend_spacing = 150
    legend_nodes = []
    for position, (topic_id, topic_info) in enumerate(topics.items()):
        legend_label = ', '.join(topic_info['Topic_Name'].split('_')[1:])
        legend_attributes = {
            'group': topic_id,
            'label': legend_label,
            'size': 30,
            'physics': False,
            'x': legend_x,
            'y': f'{legend_top + position*legend_spacing}px',
            'shape': 'box',
            'widthConstraint': 1000,
            'font': {'size': 40, 'color': 'black'},
        }
        legend_nodes.append((len(topic_data) + position, legend_attributes))
    graph.add_nodes_from(legend_nodes)

    # Hand the finished graph to pyvis for rendering
    canvas.from_nx(graph)
    return graph, canvas
def text_processing(text, max_length=500):
    """Format raw text as short HTML for tooltips and hover labels.

    ``'[SEP]'`` markers become ``<br><br>``, the result is wrapped to 50
    characters per line with ``<br>`` between lines, and the HTML is cut to
    at most ``max_length`` characters.

    Args:
        text: The raw string to format.
        max_length: Maximum length of the HTML before the ellipsis is added.

    Returns:
        The formatted HTML. ``'...'`` is appended only when the text was
        actually truncated.
    """
    sections = text.split('[SEP]')
    html = '<br>'.join(textwrap.wrap('<br><br>'.join(sections), width=50))

    # Short text is returned as-is: an ellipsis would wrongly suggest that
    # something had been cut off.
    if len(html) <= max_length:
        return html

    clipped = html[:max_length]
    # The cut can land inside a '<br>' tag; drop the dangling fragment so the
    # output never ends with broken markup such as '<b...'.
    for fragment in ('<br', '<b', '<'):
        if clipped.endswith(fragment):
            clipped = clipped[:-len(fragment)]
            break
    return clipped + '...'
def network_centrality(topic_data, centrality, centrality_option):
    """Plot the 10 nodes with the highest centrality as a horizontal bar chart.

    Args:
        topic_data: DataFrame with ``Topic_Name`` and ``Text`` columns; its
            index is matched against the keys of ``centrality``.
        centrality: Mapping of node id to centrality score.
        centrality_option: Name used for the score column and the x axis.

    Returns:
        The Plotly figure.
    """
    logger.info('Calculating Network Centrality')

    # Turn the score mapping into a frame indexed by node id
    ranked = sorted(centrality.items(), key=lambda pair: pair[1], reverse=True)
    scores = pd.DataFrame(ranked, columns=['node', centrality_option])
    scores = scores.set_index('node')

    # Attach scores to the papers and keep the 10 highest
    top_nodes = (
        topic_data.join(scores)
        .sort_values(centrality_option, ascending=False)
        .head(10)
        .reset_index()
    )

    # Prepare the columns for plotting
    top_nodes['index'] = top_nodes['index'].astype(str)
    top_nodes['Topic_Name'] = top_nodes['Topic_Name'].apply(
        lambda name: ', '.join(name.split('_')[1:]))
    top_nodes['Text'] = top_nodes['Text'].apply(text_processing)

    # One bar per node, colored by topic; the y axis is hidden
    fig = px.bar(
        top_nodes,
        x=centrality_option,
        y='index',
        color='Topic_Name',
        hover_data=['Text'],
        orientation='h',
    )
    y_axis = {
        'categoryorder': 'total ascending',
        'visible': False,
        'showticklabels': False,
    }
    fig.update_layout(yaxis=y_axis, font={'size': 15}, height=800)
    return fig
| # Progress bar printer | |
| # https://github.com/BugzTheBunny/streamlit_logging_output_example/blob/main/app.py | |
| # https://discuss.streamlit.io/t/cannot-print-the-terminal-output-in-streamlit/6602/34 | |
| # @contextmanager | |
| # def st_redirect(src, dst): | |
| # placeholder = st.empty() | |
| # output_func = getattr(placeholder, dst) | |
| # with StringIO() as buffer: | |
| # old_write = src.write | |
| # def new_write(b): | |
| # if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None): | |
| # buffer.write(b) | |
| # time.sleep(1) | |
| # buffer.seek(0) # returns pointer to 0 position | |
| # output_func(b) | |
| # else: | |
| # old_write(b) | |
| # try: | |
| # src.write = new_write | |
| # yield | |
| # finally: | |
| # src.write = old_write | |
| # @contextmanager | |
| # def st_stdout(dst): | |
| # "this will show the prints" | |
| # with st_redirect(sys.stdout, dst): | |
| # yield | |
| # @contextmanager | |
| # def st_stderr(dst): | |
| # "This will show the logging" | |
| # with st_redirect(sys.stderr, dst): | |
| # yield | |