# textTree / texttree.py
# Author: rockerritesh
# Commit 789d822 (verified): "Rename app.py to texttree.py"
import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import numpy as np
import plotly.graph_objs as go
# Make sure the Punkt tokenizer data required by sent_tokenize is installed.
# Streamlit re-runs this whole script on every widget interaction, so only
# call nltk.download (which contacts the NLTK index server) when the resource
# is not already available locally.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')
# Helper function to split text into topics using KMeans clustering and extract top words
def split_text_into_topics(text, n_topics):
    """Cluster the sentences of *text* into topics with TF-IDF + KMeans.

    Args:
        text: Raw text; it is split into sentences with NLTK's ``sent_tokenize``.
        n_topics: Requested number of clusters. It is capped at the number of
            sentences, because KMeans cannot form more clusters than samples.

    Returns:
        A ``(topic_sentences, top_words)`` tuple. ``topic_sentences`` maps each
        cluster id ``0..k-1`` to the list of its sentences, in their original
        order. ``top_words[i]`` is the vocabulary term with the largest weight
        in cluster ``i``'s centroid. ``k`` is ``min(n_topics, number of
        sentences)``. Text with no sentences yields ``({}, [])``. Text whose
        sentences contain no usable terms (e.g. only English stop words)
        yields a single topic and no top words.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        return {}, []

    vectorizer = TfidfVectorizer(stop_words='english')
    try:
        X = vectorizer.fit_transform(sentences)
    except ValueError:
        # scikit-learn raises ValueError ("empty vocabulary") when no token
        # survives stop-word removal; there is nothing to cluster on, so keep
        # all sentences together as one topic.
        return {0: list(sentences)}, []

    # KMeans requires n_samples >= n_clusters. Recursive splitting routinely
    # hands this function clusters with only a few sentences.
    n_clusters = min(n_topics, len(sentences))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(X)

    # Highest-weighted term of each centroid; hoist the vocabulary lookup.
    feature_names = vectorizer.get_feature_names_out()
    top_words = [feature_names[np.argmax(center)] for center in kmeans.cluster_centers_]

    topic_sentences = {i: [] for i in range(n_clusters)}
    for sentence, cluster in zip(sentences, kmeans.labels_):
        topic_sentences[int(cluster)].append(sentence)
    return topic_sentences, top_words
# Recursive function to split subtopics
def recursive_split(topic_dict, depth, max_depth, subtopics):
    """Recursively re-cluster each topic's sentences into nested subtopics.

    Args:
        topic_dict: Mapping of topic id -> list of sentences, as returned by
            ``split_text_into_topics``.
        depth: Current recursion depth (top-level callers pass 0).
        max_depth: Splitting stops once ``depth`` reaches this value, so
            ``max_depth - depth`` levels of subtopics are created.
        subtopics: Number of subtopics requested for each split. It is capped
            at the topic's sentence count.

    Returns:
        A new mapping with the same keys as *topic_dict*. Each value is either
        the topic's sentence list (a leaf) or a nested mapping of the same
        shape. A topic stays a leaf when the depth limit is reached, when it
        has at most one sentence, or when re-clustering yields a single group.
        When ``depth >= max_depth``, *topic_dict* itself is returned unchanged.
    """
    if depth >= max_depth:
        # Previously returned None here, which crashed get_edges downstream.
        return topic_dict

    new_topic_dict = {}
    for topic, sentences in topic_dict.items():
        if len(sentences) <= 1:
            new_topic_dict[topic] = sentences
            continue
        # Never ask KMeans for more clusters than there are sentences.
        n_sub = min(subtopics, len(sentences))
        sub_topics, _ = split_text_into_topics(' '.join(sentences), n_sub)
        if len(sub_topics) <= 1:
            # A single group would only add a redundant one-child level.
            new_topic_dict[topic] = sentences
        else:
            # Actually recurse, so max_depth > 1 produces deeper levels.
            new_topic_dict[topic] = recursive_split(sub_topics, depth + 1, max_depth, subtopics)
    return new_topic_dict
# Function to convert the tree into edge data for Plotly visualization
def get_edges(tree, parent=None, level=0, top_words=None):
    """Flatten a topic tree into edge, label, hover and position data for Plotly.

    Args:
        tree: Mapping of topic id -> either a list of sentences (leaf) or a
            nested mapping of the same shape (subtopics), as produced by
            ``recursive_split``.
        parent: Label of the node *tree* hangs under, or None for the root
            level. Root-level nodes are labelled ``Topic <key>``. Nodes under a
            parent are labelled ``Subtopic <path>``, where ``<path>`` is the
            dot-joined chain of keys from this call's root (e.g.
            ``Subtopic 2.0``), so labels are unique across the whole tree.
        level: x coordinate (depth) given to the nodes of *tree*.
        top_words: Optional list indexed by the keys of *tree*;
            ``top_words[key]`` is shown in that node's hover text. It describes
            only the outermost level, so nested subtopics show ``N/A``.

    Returns:
        ``(edges, labels, hover_texts, pos)``:

        - ``edges`` is a list of ``(parent_label, child_label)`` pairs.
        - ``labels`` and ``hover_texts`` are parallel lists, one entry per
          node, in depth-first order.
        - ``pos`` maps every label to ``(x, y)``, where x is the depth and y
          is the node's depth-first row index, so no two nodes share a
          position.
    """
    edges = []
    labels = []
    hover_texts = []
    pos = {}

    def _visit(subtree, parent_label, depth, path, words):
        # Append one subtree's nodes to the shared accumulators. The y value
        # uses the global row count so sibling subtrees never overlap.
        for key, value in subtree.items():
            node_path = path + (key,)
            if parent_label is None:
                node_label = f'Topic {key}'
            else:
                node_label = 'Subtopic ' + '.'.join(str(k) for k in node_path)
            pos[node_label] = (depth, len(labels))
            top_word = words[key] if words and key < len(words) else "N/A"
            labels.append(node_label)
            hover_texts.append(f"Top Word: {top_word}")
            if parent_label is not None:
                edges.append((parent_label, node_label))
            if isinstance(value, dict):
                # top_words indexes the outermost clusters only; reusing it for
                # unrelated subtopic keys would show misleading words.
                _visit(value, node_label, depth + 1, node_path, None)
            else:
                for i, sentence in enumerate(value, start=1):
                    sentence_label = f"{node_label} - Sentence {i}"
                    pos[sentence_label] = (depth + 1, len(labels))
                    labels.append(sentence_label)
                    hover_texts.append(sentence)
                    edges.append((node_label, sentence_label))

    _visit(tree, parent, level, (), top_words)
    return edges, labels, hover_texts, pos
# Streamlit App layout
# NOTE: Streamlit re-runs this whole script from the top on every widget
# interaction, so the clustering below is recomputed whenever a slider moves.
# st.title('Interactive Text Topic Tree Generator')
# Upload file
uploaded_file = st.file_uploader("Upload a text file", type="txt")
if uploaded_file is not None:
    # errors='replace' turns undecodable bytes into U+FFFD instead of crashing
    # the app with UnicodeDecodeError on a non-UTF-8 upload.
    text = uploaded_file.read().decode('utf-8', errors='replace')
    if not text.strip():
        # Nothing to tokenize; TF-IDF/KMeans would raise on an empty corpus.
        st.warning("The uploaded file contains no text.")
        st.stop()

    # Select number of main topics and depth of subtopics
    n_topics = st.slider('Select number of main topics', 2, 10, 5)
    max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
    subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)

    # Split text into main topics and extract top words
    topic_dict, top_words = split_text_into_topics(text, n_topics)
    # Recursively split the topics into subtopics
    full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
    # Get edges, labels, hover texts, and positions for the plot
    edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words)

    # Plot the tree graph using Plotly.
    # All edges go into one line trace; a None between segments tells Plotly
    # to lift the pen so the segments are not joined.
    edge_x = []
    edge_y = []
    for source, target in edges:
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    node_x = [pos[label][0] for label in labels]
    node_y = [pos[label][1] for label in labels]

    # Create edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='Gray'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node trace with hover text showing top words / sentences
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=labels,
        hoverinfo='text',
        hovertext=hover_texts,  # Adding hover text
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            # The x coordinate from get_edges is the tree depth; colouring by it
            # gives the 'Depth' colorbar something to show.
            color=node_x,
            size=20,
            colorbar=dict(
                thickness=15,
                # 'titleside' is deprecated and removed in Plotly 6;
                # title.side is the supported form.
                title=dict(text='Depth', side='right'),
                xanchor='left',
            ),
            line_width=2
        )
    )

    # Plot the figure
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=0),
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False)
                    ))
    st.plotly_chart(fig)