File size: 5,492 Bytes
5d0296f
 
 
588731e
5d0296f
 
 
 
ba1c1a9
 
5d0296f
588731e
1c98db8
7c47743
 
5d0296f
 
 
 
 
 
 
 
 
 
 
7c47743
 
 
 
 
 
 
 
 
5d0296f
 
 
7c47743
5d0296f
 
 
 
 
 
 
 
 
 
 
7c47743
5d0296f
 
 
 
d7e685d
7c47743
d7e685d
 
7c47743
d7e685d
 
5d0296f
 
d7e685d
7c47743
d7e685d
7c47743
5d0296f
d7e685d
 
 
5d0296f
7c47743
d7e685d
 
7c47743
d7e685d
5d0296f
 
 
d7e685d
 
7c47743
d7e685d
5d0296f
7c47743
5d0296f
 
789d822
5d0296f
 
 
 
 
 
 
 
 
 
 
 
7c47743
 
5d0296f
 
 
 
7c47743
 
d7e685d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c47743
d7e685d
 
 
 
 
7c47743
d7e685d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d0296f
d7e685d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import numpy as np
import plotly.graph_objs as go


# Make sure the punkt_tab tokenizer data is available. Streamlit re-executes
# this script on every interaction, so only hit the download path when the
# resource is actually missing locally.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

# Helper function to split text into topics using KMeans clustering and extract top words
def split_text_into_topics(text, n_topics):
    """Cluster the sentences of ``text`` into topics with TF-IDF + KMeans.

    Args:
        text: Raw text; it is split into sentences with ``sent_tokenize``.
        n_topics: Requested number of clusters. It is capped at the number of
            sentences because KMeans cannot form more clusters than samples.

    Returns:
        A ``(topic_sentences, top_words)`` tuple. ``topic_sentences`` maps each
        cluster index to the sentences assigned to it (in their original
        order). ``top_words[i]`` is the vocabulary term with the largest
        weight in the centre of cluster ``i``. Both are empty when ``text``
        contains no sentences.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        return {}, []

    # Asking KMeans for more clusters than there are sentences raises, so cap it.
    n_clusters = min(n_topics, len(sentences))

    vectorizer = TfidfVectorizer(stop_words='english')
    X = vectorizer.fit_transform(sentences)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(X)

    # Look the vocabulary up once instead of once per cluster.
    feature_names = vectorizer.get_feature_names_out()
    top_words = [
        feature_names[np.argmax(center)] for center in kmeans.cluster_centers_
    ]

    topic_sentences = {i: [] for i in range(n_clusters)}
    for cluster, sentence in zip(kmeans.labels_.tolist(), sentences):
        topic_sentences[cluster].append(sentence)

    return topic_sentences, top_words

# Recursive function to split subtopics
def recursive_split(topic_dict, depth, max_depth, subtopics):
    """Split every topic into nested sub-topics, down to ``max_depth`` levels.

    Args:
        topic_dict: Mapping of topic key to a list of sentences.
        depth: Current nesting level (0 for the initial call).
        max_depth: Number of levels of sub-topics to create.
        subtopics: Number of sub-topics to request for each split.

    Returns:
        A new mapping with the same keys. A value is either a list of
        sentences (a leaf) or a nested mapping of sub-topics. When ``depth``
        has already reached ``max_depth`` the input mapping is returned
        unchanged, so callers always receive a tree rather than ``None``.
    """
    if depth >= max_depth:
        return topic_dict

    new_topic_dict = {}
    for topic, sentences in topic_dict.items():
        if len(sentences) <= 1:
            # Nothing to split.
            new_topic_dict[topic] = sentences
            continue

        # Never ask for more clusters than there are sentences to cluster.
        n_sub = min(subtopics, len(sentences))
        try:
            sub_topics, _ = split_text_into_topics(' '.join(sentences), n_sub)
        except ValueError:
            # The group cannot be clustered (e.g. nothing left after stop-word
            # removal); keep it as a leaf instead of failing the whole tree.
            new_topic_dict[topic] = sentences
            continue

        new_topic_dict[topic] = recursive_split(
            sub_topics, depth + 1, max_depth, subtopics
        )

    return new_topic_dict

# Function to convert the tree into edge data for Plotly visualization
def get_edges(tree, parent=None, level=0, top_words=None, y_offset=0):
    """Flatten a topic tree into edges, labels, hover texts and positions.

    Args:
        tree: Mapping of topic key to either a nested mapping (sub-topics) or
            a list of sentences.
        parent: Label of the parent node, or ``None`` for the root level.
        level: Depth of the nodes in ``tree``; used as the x coordinate.
        top_words: Top word per root-level topic, indexed by topic key. The
            words describe the root topics only, so sub-topic nodes always
            show "N/A".
        y_offset: Number of nodes already placed before this subtree. Used to
            give every node in the whole tree its own y coordinate.

    Returns:
        ``(edges, labels, hover_texts, pos)``. ``edges`` is a list of
        ``(parent_label, child_label)`` pairs, ``labels`` and ``hover_texts``
        are parallel lists in pre-order, and ``pos`` maps each label to
        ``(level, row)``. Labels are qualified with their parent's label so
        they are unique across the tree and safe to use as ``pos`` keys.
    """
    edges = []
    labels = []
    hover_texts = []
    pos = {}

    for key, value in tree.items():
        if parent is None:
            node_label = f'Topic {key}'
            has_word = bool(top_words) and 0 <= key < len(top_words)
            top_word = top_words[key] if has_word else "N/A"
        else:
            # Prefix with the parent so "Subtopic 0" under different parents
            # does not collapse into a single node.
            node_label = f'{parent} / Subtopic {key}'
            top_word = "N/A"

        pos[node_label] = (level, y_offset + len(labels))
        labels.append(node_label)
        hover_texts.append(f"Top Word: {top_word}")

        if parent is not None:
            edges.append((parent, node_label))

        if isinstance(value, dict):
            child_edges, child_labels, child_hovers, child_pos = get_edges(
                value, node_label, level + 1, top_words, y_offset + len(labels)
            )
            edges += child_edges
            labels += child_labels
            hover_texts += child_hovers
            pos.update(child_pos)
        else:
            for number, sentence in enumerate(value, start=1):
                sentence_label = f"{node_label} - Sentence {number}"
                pos[sentence_label] = (level + 1, y_offset + len(labels))
                labels.append(sentence_label)
                hover_texts.append(sentence)
                edges.append((node_label, sentence_label))

    return edges, labels, hover_texts, pos

# Streamlit App layout
# st.title('Interactive Text Topic Tree Generator')

# Upload file
uploaded_file = st.file_uploader("Upload a text file", type="txt")

if uploaded_file is not None:
    # Replace undecodable bytes instead of aborting on a non-UTF-8 upload.
    text = uploaded_file.read().decode('utf-8', errors='replace')

    # Select number of main topics and depth of subtopics
    n_topics = st.slider('Select number of main topics', 2, 10, 5)
    max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
    subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)

    try:
        # Split text into main topics and extract top words
        topic_dict, top_words = split_text_into_topics(text, n_topics)

        # Recursively split the topics into subtopics
        full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
    except ValueError as exc:
        # Raised by the vectorizer/clustering for texts that cannot be
        # clustered (e.g. too few sentences or only stop words). Show the
        # reason in the page instead of a raw traceback.
        st.error(f"Could not split the uploaded text into topics: {exc}")
        st.stop()

    # Get edges, labels, hover texts, and positions for the plot
    edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words)

    # Plot the tree graph using Plotly. Each edge is a separate line segment;
    # the None entries break the line between consecutive edges.
    edge_x = []
    edge_y = []
    for source, target in edges:
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    node_x = [pos[label][0] for label in labels]
    node_y = [pos[label][1] for label in labels]

    # Create edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='Gray'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node trace with hover text showing top words
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=labels,
        hoverinfo='text',
        hovertext=hover_texts,  # Adding hover text
        marker=dict(
            # The x coordinate is the node's depth, so colouring by it gives
            # the "Depth" colour bar actual data to display.
            color=node_x,
            showscale=True,
            colorscale='YlGnBu',
            size=20,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the `titleside` shorthand, which
                # current Plotly releases reject.
                title=dict(text='Depth', side='right'),
                xanchor='left'
            ),
            line_width=2
        )
    )

    # Plot the figure
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=0),
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False)
                    ))

    st.plotly_chart(fig)