rockerritesh committed on
Commit
d7e685d
·
verified ·
1 Parent(s): 1c98db8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -23
app.py CHANGED
@@ -9,7 +9,6 @@ import numpy as np
9
 
10
  # Download the punkt tokenizer
11
  nltk.download('punkt_tab')
12
-
13
  # Helper function to split text into topics using KMeans clustering
14
  def split_text_into_topics(text, n_topics):
15
  sentences = sent_tokenize(text)
@@ -42,29 +41,36 @@ def recursive_split(topic_dict, depth, max_depth, subtopics):
42
 
43
  return new_topic_dict
44
 
45
- # Plotting function to visualize the tree structure
46
- def plot_tree(tree, parent=None, graph=None, level=0):
47
- if graph is None:
48
- graph = nx.Graph()
49
-
 
50
  for key, value in tree.items():
51
  node_label = f'Topic {key}' if parent is None else f'Subtopic {key}'
52
- graph.add_node(node_label, level=level)
53
- if parent:
54
- graph.add_edge(parent, node_label)
55
 
 
 
 
56
  if isinstance(value, dict):
57
- plot_tree(value, parent=node_label, graph=graph, level=level+1)
 
 
 
58
  else:
59
  for i, sentence in enumerate(value):
60
  sentence_label = f"{node_label} - Sentence {i+1}"
61
- graph.add_node(sentence_label, level=level+1)
62
- graph.add_edge(node_label, sentence_label)
 
63
 
64
- return graph
65
 
66
  # Streamlit App layout
67
- st.title('Text Topic Tree Generator')
68
 
69
  # Upload file
70
  uploaded_file = st.file_uploader("Upload a text file", type="txt")
@@ -83,13 +89,57 @@ if uploaded_file is not None:
83
  # Recursively split the topics into subtopics
84
  full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
85
 
86
- # Create and display the tree graph
87
- graph = plot_tree(full_tree)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Plot the tree graph
90
- pos = nx.spring_layout(graph)
91
- levels = nx.get_node_attributes(graph, 'level')
92
- plt.figure(figsize=(12, 8))
93
- nx.draw(graph, pos, with_labels=True, node_size=3000, node_color="lightblue", font_size=10, font_weight="bold", labels={node: node for node in graph.nodes()})
94
- plt.title("Tree Structure of Text Topics")
95
- st.pyplot(plt)
 
9
 
10
  # Download the punkt tokenizer
11
  nltk.download('punkt_tab')
 
12
  # Helper function to split text into topics using KMeans clustering
13
  def split_text_into_topics(text, n_topics):
14
  sentences = sent_tokenize(text)
 
41
 
42
  return new_topic_dict
43
 
44
+ # Function to convert the tree into edge data for Plotly visualization
45
+ def get_edges(tree, parent=None, level=0):
46
+ edges = []
47
+ labels = []
48
+ pos = {}
49
+
50
  for key, value in tree.items():
51
  node_label = f'Topic {key}' if parent is None else f'Subtopic {key}'
52
+ pos[node_label] = (level, len(labels))
53
+ labels.append(node_label)
 
54
 
55
+ if parent:
56
+ edges.append((parent, node_label))
57
+
58
  if isinstance(value, dict):
59
+ new_edges, new_labels, new_pos = get_edges(value, node_label, level+1)
60
+ edges += new_edges
61
+ labels += new_labels
62
+ pos.update(new_pos)
63
  else:
64
  for i, sentence in enumerate(value):
65
  sentence_label = f"{node_label} - Sentence {i+1}"
66
+ pos[sentence_label] = (level+1, len(labels))
67
+ labels.append(sentence_label)
68
+ edges.append((node_label, sentence_label))
69
 
70
+ return edges, labels, pos
71
 
72
  # Streamlit App layout
73
+ st.title('Interactive Text Topic Tree Generator')
74
 
75
  # Upload file
76
  uploaded_file = st.file_uploader("Upload a text file", type="txt")
 
89
  # Recursively split the topics into subtopics
90
  full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
91
 
92
+ # Get edges and positions for the plot
93
+ edges, labels, pos = get_edges(full_tree)
94
+
95
+ # Plot the tree graph using Plotly
96
+ edge_x = []
97
+ edge_y = []
98
+ for edge in edges:
99
+ x0, y0 = pos[edge[0]]
100
+ x1, y1 = pos[edge[1]]
101
+ edge_x += [x0, x1, None]
102
+ edge_y += [y0, y1, None]
103
+
104
+ node_x = [pos[label][0] for label in labels]
105
+ node_y = [pos[label][1] for label in labels]
106
+
107
+ # Create edge trace
108
+ edge_trace = go.Scatter(
109
+ x=edge_x, y=edge_y,
110
+ line=dict(width=2, color='Gray'),
111
+ hoverinfo='none',
112
+ mode='lines'
113
+ )
114
+
115
+ # Create node trace
116
+ node_trace = go.Scatter(
117
+ x=node_x, y=node_y,
118
+ mode='markers+text',
119
+ text=labels,
120
+ hoverinfo='text',
121
+ marker=dict(
122
+ showscale=True,
123
+ colorscale='YlGnBu',
124
+ size=20,
125
+ colorbar=dict(
126
+ thickness=15,
127
+ title='Depth',
128
+ xanchor='left',
129
+ titleside='right'
130
+ ),
131
+ line_width=2
132
+ )
133
+ )
134
+
135
+ # Plot the figure
136
+ fig = go.Figure(data=[edge_trace, node_trace],
137
+ layout=go.Layout(
138
+ showlegend=False,
139
+ hovermode='closest',
140
+ margin=dict(b=0, l=0, r=0, t=0),
141
+ xaxis=dict(showgrid=False, zeroline=False),
142
+ yaxis=dict(showgrid=False, zeroline=False)
143
+ ))
144
 
145
+ st.plotly_chart(fig)