Spaces:
Runtime error
Runtime error
File size: 5,492 Bytes
5d0296f 588731e 5d0296f ba1c1a9 5d0296f 588731e 1c98db8 7c47743 5d0296f 7c47743 5d0296f 7c47743 5d0296f 7c47743 5d0296f d7e685d 7c47743 d7e685d 7c47743 d7e685d 5d0296f d7e685d 7c47743 d7e685d 7c47743 5d0296f d7e685d 5d0296f 7c47743 d7e685d 7c47743 d7e685d 5d0296f d7e685d 7c47743 d7e685d 5d0296f 7c47743 5d0296f 789d822 5d0296f 7c47743 5d0296f 7c47743 d7e685d 7c47743 d7e685d 7c47743 d7e685d 5d0296f d7e685d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import numpy as np
import plotly.graph_objs as go
# Download the tokenizer models needed by sent_tokenize. Newer NLTK
# releases use "punkt_tab" while older ones ship "punkt" — fetch both
# (quietly) so the app starts regardless of the installed NLTK version.
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)
# Helper function to split text into topics using KMeans clustering and extract top words
def split_text_into_topics(text, n_topics):
    """Cluster the sentences of *text* into at most *n_topics* topics.

    Parameters
    ----------
    text : str
        Raw text; it is split into sentences with NLTK's ``sent_tokenize``.
    n_topics : int
        Requested number of clusters. Clamped to the number of sentences,
        because KMeans raises ``ValueError`` when ``n_clusters > n_samples``
        (this happens on short texts and during deep recursive splits).

    Returns
    -------
    (dict, list)
        ``topic_sentences`` maps cluster index -> list of sentences, and
        ``top_words`` holds the highest-weighted TF-IDF term per cluster.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        # No sentences at all: return empty structures instead of crashing.
        return {}, []
    # KMeans requires n_clusters <= n_samples; clamp so short inputs work.
    n_topics = max(1, min(n_topics, len(sentences)))
    vectorizer = TfidfVectorizer(stop_words='english')
    X = vectorizer.fit_transform(sentences)
    # n_init pinned explicitly: the sklearn default changed across versions.
    kmeans = KMeans(n_clusters=n_topics, random_state=42, n_init=10)
    kmeans.fit(X)
    clusters = kmeans.labels_.tolist()
    topic_sentences = {i: [] for i in range(n_topics)}
    # Store the top word (largest centroid weight) for each cluster.
    top_words = []
    feature_names = vectorizer.get_feature_names_out()
    for i in range(n_topics):
        cluster_center = kmeans.cluster_centers_[i]
        top_word_index = int(np.argmax(cluster_center))
        top_words.append(feature_names[top_word_index])
    for i, sentence in enumerate(sentences):
        topic_sentences[clusters[i]].append(sentence)
    return topic_sentences, top_words
# Recursive function to split subtopics
def recursive_split(topic_dict, depth, max_depth, subtopics):
    """Recursively split each topic's sentences into *subtopics* clusters.

    Parameters
    ----------
    topic_dict : dict
        Mapping of topic key -> list of sentences.
    depth : int
        Current recursion depth (start at 0).
    max_depth : int
        Maximum depth of subtopic splitting.
    subtopics : int
        Number of sub-clusters to create per topic.

    Returns
    -------
    dict
        A nested tree: values are either sentence lists (leaves) or dicts of
        subtopics.

    Fixes two bugs in the original: it returned ``None`` (instead of the
    tree) once ``depth >= max_depth``, and it never actually recursed, so
    ``max_depth`` values beyond 1 had no effect.
    """
    if depth >= max_depth:
        # Base case: return the tree unchanged (the original returned None
        # here, which crashed the downstream plotting code).
        return topic_dict
    new_topic_dict = {}
    for topic, sentences in topic_dict.items():
        if len(sentences) <= 1:
            # Too few sentences to split further; keep as a leaf.
            new_topic_dict[topic] = sentences
        else:
            sub_topics, _ = split_text_into_topics(' '.join(sentences), subtopics)
            # Recurse so max_depth > 1 produces deeper trees.
            new_topic_dict[topic] = recursive_split(
                sub_topics, depth + 1, max_depth, subtopics)
    return new_topic_dict
# Function to convert the tree into edge data for Plotly visualization
def get_edges(tree, parent=None, level=0, top_words=None):
    """Flatten a nested topic tree into Plotly-ready edge/label/position data.

    Parameters
    ----------
    tree : dict
        Topic key -> either a nested dict (subtopics) or a list of sentences.
    parent : str or None
        Label of the parent node; ``None`` at the root level.
    level : int
        Current depth, used as the node's x-coordinate.
    top_words : list[str] or None
        Top word per cluster index, shown in the node hover text.

    Returns
    -------
    (edges, labels, hover_texts, pos)
        ``edges``   : list of (parent_label, child_label) pairs,
        ``labels``  : node labels in insertion order,
        ``hover_texts``: hover string per node (parallel to ``labels``),
        ``pos``     : label -> (x, y) plot coordinate.
    """
    edges = []
    labels = []
    hover_texts = []
    pos = {}
    for key, value in tree.items():
        # Prefix subtopic labels with their parent path so labels stay
        # unique across branches — the original used bare 'Subtopic {key}',
        # which collided between siblings and overwrote their positions.
        if parent is None:
            node_label = f'Topic {key}'
        else:
            node_label = f'{parent} / Subtopic {key}'
        pos[node_label] = (level, len(labels))
        top_word = top_words[key] if top_words and key < len(top_words) else "N/A"
        labels.append(node_label)
        hover_texts.append(f"Top Word: {top_word}")
        if parent:
            edges.append((parent, node_label))
        if isinstance(value, dict):
            # Internal node: recurse into the subtopic dict.
            new_edges, new_labels, new_hover_texts, new_pos = get_edges(
                value, node_label, level + 1, top_words)
            edges += new_edges
            labels += new_labels
            hover_texts += new_hover_texts
            pos.update(new_pos)
        else:
            # Leaf node: each sentence becomes its own child node whose
            # hover text is the sentence itself.
            for i, sentence in enumerate(value):
                sentence_label = f"{node_label} - Sentence {i+1}"
                pos[sentence_label] = (level + 1, len(labels))
                labels.append(sentence_label)
                hover_texts.append(sentence)
                edges.append((node_label, sentence_label))
    return edges, labels, hover_texts, pos
# Streamlit App layout
# st.title('Interactive Text Topic Tree Generator')

# Upload file
uploaded_file = st.file_uploader("Upload a text file", type="txt")
if uploaded_file is not None:
    text = uploaded_file.read().decode('utf-8')

    # Select number of main topics and depth of subtopics
    n_topics = st.slider('Select number of main topics', 2, 10, 5)
    max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
    subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)

    # Split text into main topics and extract top words
    topic_dict, top_words = split_text_into_topics(text, n_topics)

    # Recursively split the topics into subtopics
    full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)

    # Get edges, labels, hover texts, and positions for the plot
    edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words)

    # Build poly-line coordinates for the edges; the None entries break the
    # line between consecutive edges so Plotly draws them as segments.
    edge_x = []
    edge_y = []
    for src, dst in edges:
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    node_x = [pos[label][0] for label in labels]
    node_y = [pos[label][1] for label in labels]

    # Create edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='Gray'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node trace with hover text showing top words / sentences
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=labels,
        hoverinfo='text',
        hovertext=hover_texts,  # Adding hover text
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=20,
            # Color nodes by tree depth: without a color array the 'Depth'
            # colorbar had no data to describe.
            color=node_x,
            colorbar=dict(
                thickness=15,
                title='Depth',
                xanchor='left',
                titleside='right'
            ),
            line_width=2
        )
    )

    # Plot the figure
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False)
        )
    )
    st.plotly_chart(fig)