Spaces:
Runtime error
| import streamlit as st | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| import nltk | |
| from nltk import sent_tokenize | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.cluster import KMeans | |
| import numpy as np | |
| import plotly.graph_objs as go | |
# Make sure the Punkt sentence-tokenizer data used by ``sent_tokenize`` is
# available. Newer NLTK releases load the ``punkt_tab`` resource while older
# ones load ``punkt``, so both are checked. Looking each resource up locally
# first avoids a network round-trip on every Streamlit rerun of this script;
# it is only downloaded (quietly) when missing.
for _punkt_resource in ('punkt_tab', 'punkt'):
    try:
        nltk.data.find(f'tokenizers/{_punkt_resource}')
    except LookupError:
        nltk.download(_punkt_resource, quiet=True)
def split_text_into_topics(text, n_topics):
    """Cluster the sentences of ``text`` into topics using TF-IDF and KMeans.

    Args:
        text: Raw text; it is split into sentences with NLTK's
            ``sent_tokenize``.
        n_topics: Requested number of topics. It is capped at the number of
            sentences, because KMeans cannot form more clusters than it has
            samples.

    Returns:
        ``(topic_sentences, top_words)``. ``topic_sentences`` maps each
        cluster index ``0..k-1`` to the sentences assigned to it, in their
        original order. ``top_words[i]`` is the term with the highest weight
        in cluster ``i``'s centroid. Text with no sentences yields
        ``({}, [])``.

    Raises:
        ValueError: Propagated from scikit-learn, e.g. when no terms remain
            after English stop-word removal or ``n_topics`` is not positive.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        return {}, []
    n_clusters = min(n_topics, len(sentences))

    vectorizer = TfidfVectorizer(stop_words='english')
    X = vectorizer.fit_transform(sentences)
    # n_init is pinned so the clustering does not depend on scikit-learn's
    # version-specific default for it.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    clusters = kmeans.fit_predict(X)

    # The highest-weighted term of each centroid. The vocabulary is fetched
    # once instead of once per cluster.
    feature_names = vectorizer.get_feature_names_out()
    top_words = [
        feature_names[idx]
        for idx in np.argmax(kmeans.cluster_centers_, axis=1)
    ]

    topic_sentences = {i: [] for i in range(n_clusters)}
    for sentence, cluster in zip(sentences, clusters):
        topic_sentences[int(cluster)].append(sentence)
    return topic_sentences, top_words
def recursive_split(topic_dict, depth, max_depth, subtopics):
    """Recursively split each topic's sentences into subtopic clusters.

    Every topic with more than one sentence is re-clustered into at most
    ``subtopics`` groups via ``split_text_into_topics``. Each resulting group
    is split again, until ``max_depth - depth`` levels have been added.

    Args:
        topic_dict: Mapping of topic key to a list of sentences.
        depth: Current depth; callers start at 0.
        max_depth: Depth at which splitting stops.
        subtopics: Maximum number of subtopics per split. It is capped at the
            number of sentences in the topic being split, because KMeans
            cannot form more clusters than it has samples.

    Returns:
        ``topic_dict`` itself when ``depth >= max_depth``. Otherwise a new
        mapping with the same keys, whose values are either the original
        sentence list (a leaf) or a nested mapping of the same shape.
    """
    if depth >= max_depth:
        return topic_dict

    new_topic_dict = {}
    for topic, sentences in topic_dict.items():
        n_clusters = min(subtopics, len(sentences))
        if n_clusters <= 1:
            # A single sentence (or a single requested cluster) cannot be
            # split any further.
            new_topic_dict[topic] = sentences
            continue
        try:
            sub_topics, _ = split_text_into_topics(' '.join(sentences), n_clusters)
        except ValueError:
            # Clustering can still fail, e.g. when the group has no terms left
            # after stop-word removal, or when the re-joined text tokenizes
            # into fewer sentences. Keep such a group as a leaf rather than
            # aborting the whole tree.
            new_topic_dict[topic] = sentences
            continue
        new_topic_dict[topic] = recursive_split(
            sub_topics, depth + 1, max_depth, subtopics
        )
    return new_topic_dict
def get_edges(tree, parent=None, level=0, top_words=None):
    """Flatten a topic tree into edge, label, hover-text and position data.

    Node labels are path-qualified so they are unique across the whole tree.
    Top-level nodes are ``"Topic 0"``, ``"Topic 1"``, …. Subtopic 1 of topic
    0 is ``"Subtopic 0.1"``, and so on. Leaf sentences are
    ``"<node label> - Sentence <n>"``.

    Args:
        tree: Mapping of key to either a nested mapping (subtopics) or a list
            of sentence strings.
        parent: Label of the node ``tree`` hangs under. ``None`` means
            ``tree`` holds the main topics.
        level: x-coordinate (depth) assigned to the nodes of ``tree``.
        top_words: Optional list indexed by the keys of ``tree``. It supplies
            the "Top Word" hover text for the nodes at this level only,
            because the words describe those clusters and not their
            subtopics.

    Returns:
        ``(edges, labels, hover_texts, pos)``:
        - ``edges``: ``(parent_label, child_label)`` pairs.
        - ``labels``: node labels in depth-first order.
        - ``hover_texts``: hover text parallel to ``labels``.
        - ``pos``: maps each label to ``(depth, row)``, where ``row`` is the
          node's index in ``labels``, so no two nodes overlap.
    """
    edges = []
    labels = []
    hover_texts = []
    pos = {}

    def add_node(label, depth, hover):
        # The row comes from the running node count shared by the whole
        # traversal, so nodes in different subtrees never share a y value.
        pos[label] = (depth, len(labels))
        labels.append(label)
        hover_texts.append(hover)

    def walk(subtree, parent_label, path, depth, words):
        for key, value in subtree.items():
            node_path = f'{path}.{key}' if path else str(key)
            if parent_label is None:
                node_label = f'Topic {node_path}'
            else:
                node_label = f'Subtopic {node_path}'
            top_word = words[key] if words and key < len(words) else "N/A"
            add_node(node_label, depth, f"Top Word: {top_word}")
            if parent_label:
                edges.append((parent_label, node_label))
            if isinstance(value, dict):
                # The top words belong to this level's clusters only, so they
                # are not passed down to the subtopics.
                walk(value, node_label, node_path, depth + 1, None)
            else:
                for i, sentence in enumerate(value, start=1):
                    sentence_label = f"{node_label} - Sentence {i}"
                    add_node(sentence_label, depth + 1, sentence)
                    edges.append((node_label, sentence_label))

    walk(tree, parent, '', level, top_words)
    return edges, labels, hover_texts, pos
# Streamlit App layout
# st.title('Interactive Text Topic Tree Generator')


def _build_tree_figure(edges, labels, hover_texts, pos):
    """Draw the topic tree produced by ``get_edges`` as a Plotly figure.

    Args:
        edges: ``(parent_label, child_label)`` pairs.
        labels: Node labels. They are shown as on-plot text and used as keys
            into ``pos``.
        hover_texts: Hover text for each node, parallel to ``labels``.
        pos: Mapping of node label to its ``(x, y)`` plot coordinates.

    Returns:
        A ``go.Figure`` with a line trace for the edges and a marker trace for
        the nodes. Nodes are coloured by their x coordinate, which
        ``get_edges`` sets to the tree depth.
    """
    edge_x = []
    edge_y = []
    for source, target in edges:
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # A None entry breaks the line, so each edge is its own segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    node_x = [pos[label][0] for label in labels]
    node_y = [pos[label][1] for label in labels]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )
    # The node trace's hover text shows the top word (topics) or the sentence
    # text (leaves).
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=labels,
        hoverinfo='text',
        hovertext=hover_texts,
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            # Without a colour array the "Depth" colorbar described nothing.
            color=node_x,
            size=20,
            colorbar=dict(
                thickness=15,
                # `titleside` is deprecated and removed in newer Plotly
                # releases; `title.side` is the supported spelling.
                title=dict(text='Depth', side='right'),
                xanchor='left',
            ),
            line_width=2,
        ),
    )
    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )


# Upload file
uploaded_file = st.file_uploader("Upload a text file", type="txt")
if uploaded_file is not None:
    # Replace undecodable bytes instead of crashing on non-UTF-8 uploads.
    text = uploaded_file.read().decode('utf-8', errors='replace')

    # Select number of main topics and depth of subtopics
    n_topics = st.slider('Select number of main topics', 2, 10, 5)
    max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
    subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)

    try:
        # Split text into main topics (with top words), then recursively
        # into subtopics.
        topic_dict, top_words = split_text_into_topics(text, n_topics)
        full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
    except ValueError as err:
        # For example: too few sentences for the requested clusters, or no
        # terms left after stop-word removal.
        st.error(f"Could not build a topic tree from this file: {err}")
    else:
        edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words)
        st.plotly_chart(_build_tree_figure(edges, labels, hover_texts, pos))