rockerritesh commited on
Commit
7c47743
·
verified ·
1 Parent(s): ba1c1a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -11,7 +11,8 @@ import plotly.graph_objs as go
11
 
12
  # Download the punkt tokenizer
13
  nltk.download('punkt_tab')
14
- # Helper function to split text into topics using KMeans clustering
 
15
  def split_text_into_topics(text, n_topics):
16
  sentences = sent_tokenize(text)
17
  vectorizer = TfidfVectorizer(stop_words='english')
@@ -23,10 +24,19 @@ def split_text_into_topics(text, n_topics):
23
  clusters = kmeans.labels_.tolist()
24
  topic_sentences = {i: [] for i in range(n_topics)}
25
 
 
 
 
 
 
 
 
 
 
26
  for i, sentence in enumerate(sentences):
27
  topic_sentences[clusters[i]].append(sentence)
28
 
29
- return topic_sentences
30
 
31
  # Recursive function to split subtopics
32
  def recursive_split(topic_dict, depth, max_depth, subtopics):
@@ -38,38 +48,43 @@ def recursive_split(topic_dict, depth, max_depth, subtopics):
38
  if len(sentences) <= 1:
39
  new_topic_dict[topic] = sentences
40
  else:
41
- sub_topics = split_text_into_topics(' '.join(sentences), subtopics)
42
  new_topic_dict[topic] = sub_topics
43
 
44
  return new_topic_dict
45
 
46
  # Function to convert the tree into edge data for Plotly visualization
47
- def get_edges(tree, parent=None, level=0):
48
  edges = []
49
  labels = []
 
50
  pos = {}
51
 
52
  for key, value in tree.items():
53
  node_label = f'Topic {key}' if parent is None else f'Subtopic {key}'
54
  pos[node_label] = (level, len(labels))
 
55
  labels.append(node_label)
 
56
 
57
  if parent:
58
  edges.append((parent, node_label))
59
 
60
  if isinstance(value, dict):
61
- new_edges, new_labels, new_pos = get_edges(value, node_label, level+1)
62
  edges += new_edges
63
  labels += new_labels
 
64
  pos.update(new_pos)
65
  else:
66
  for i, sentence in enumerate(value):
67
  sentence_label = f"{node_label} - Sentence {i+1}"
68
  pos[sentence_label] = (level+1, len(labels))
69
  labels.append(sentence_label)
 
70
  edges.append((node_label, sentence_label))
71
 
72
- return edges, labels, pos
73
 
74
  # Streamlit App layout
75
  st.title('Interactive Text Topic Tree Generator')
@@ -85,14 +100,14 @@ if uploaded_file is not None:
85
  max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
86
  subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)
87
 
88
- # Split text into main topics
89
- topic_dict = split_text_into_topics(text, n_topics)
90
 
91
  # Recursively split the topics into subtopics
92
  full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
93
 
94
- # Get edges and positions for the plot
95
- edges, labels, pos = get_edges(full_tree)
96
 
97
  # Plot the tree graph using Plotly
98
  edge_x = []
@@ -114,12 +129,13 @@ if uploaded_file is not None:
114
  mode='lines'
115
  )
116
 
117
- # Create node trace
118
  node_trace = go.Scatter(
119
  x=node_x, y=node_y,
120
  mode='markers+text',
121
  text=labels,
122
  hoverinfo='text',
 
123
  marker=dict(
124
  showscale=True,
125
  colorscale='YlGnBu',
 
11
 
12
  # Download the punkt tokenizer
13
  nltk.download('punkt_tab')
14
+
15
+ # Helper function to split text into topics using KMeans clustering and extract top words
16
  def split_text_into_topics(text, n_topics):
17
  sentences = sent_tokenize(text)
18
  vectorizer = TfidfVectorizer(stop_words='english')
 
24
  clusters = kmeans.labels_.tolist()
25
  topic_sentences = {i: [] for i in range(n_topics)}
26
 
27
+ # Store the top word for each cluster
28
+ top_words = []
29
+ for i in range(n_topics):
30
+ cluster_center = kmeans.cluster_centers_[i]
31
+ sorted_indices = np.argsort(cluster_center)[::-1]
32
+ top_word_index = sorted_indices[0]
33
+ top_word = vectorizer.get_feature_names_out()[top_word_index]
34
+ top_words.append(top_word)
35
+
36
  for i, sentence in enumerate(sentences):
37
  topic_sentences[clusters[i]].append(sentence)
38
 
39
+ return topic_sentences, top_words
40
 
41
  # Recursive function to split subtopics
42
  def recursive_split(topic_dict, depth, max_depth, subtopics):
 
48
  if len(sentences) <= 1:
49
  new_topic_dict[topic] = sentences
50
  else:
51
+ sub_topics, _ = split_text_into_topics(' '.join(sentences), subtopics)
52
  new_topic_dict[topic] = sub_topics
53
 
54
  return new_topic_dict
55
 
56
  # Function to convert the tree into edge data for Plotly visualization
57
+ def get_edges(tree, parent=None, level=0, top_words=None):
58
  edges = []
59
  labels = []
60
+ hover_texts = []
61
  pos = {}
62
 
63
  for key, value in tree.items():
64
  node_label = f'Topic {key}' if parent is None else f'Subtopic {key}'
65
  pos[node_label] = (level, len(labels))
66
+ top_word = top_words[key] if top_words and key < len(top_words) else "N/A"
67
  labels.append(node_label)
68
+ hover_texts.append(f"Top Word: {top_word}")
69
 
70
  if parent:
71
  edges.append((parent, node_label))
72
 
73
  if isinstance(value, dict):
74
+ new_edges, new_labels, new_hover_texts, new_pos = get_edges(value, node_label, level+1, top_words)
75
  edges += new_edges
76
  labels += new_labels
77
+ hover_texts += new_hover_texts
78
  pos.update(new_pos)
79
  else:
80
  for i, sentence in enumerate(value):
81
  sentence_label = f"{node_label} - Sentence {i+1}"
82
  pos[sentence_label] = (level+1, len(labels))
83
  labels.append(sentence_label)
84
+ hover_texts.append(sentence)
85
  edges.append((node_label, sentence_label))
86
 
87
+ return edges, labels, hover_texts, pos
88
 
89
  # Streamlit App layout
90
  st.title('Interactive Text Topic Tree Generator')
 
100
  max_depth = st.slider('Select maximum depth of subtopics', 1, 5, 2)
101
  subtopics_per_topic = st.slider('Select number of subtopics per topic', 2, 5, 3)
102
 
103
+ # Split text into main topics and extract top words
104
+ topic_dict, top_words = split_text_into_topics(text, n_topics)
105
 
106
  # Recursively split the topics into subtopics
107
  full_tree = recursive_split(topic_dict, 0, max_depth, subtopics_per_topic)
108
 
109
+ # Get edges, labels, hover texts, and positions for the plot
110
+ edges, labels, hover_texts, pos = get_edges(full_tree, top_words=top_words)
111
 
112
  # Plot the tree graph using Plotly
113
  edge_x = []
 
129
  mode='lines'
130
  )
131
 
132
+ # Create node trace with hover text showing top words
133
  node_trace = go.Scatter(
134
  x=node_x, y=node_y,
135
  mode='markers+text',
136
  text=labels,
137
  hoverinfo='text',
138
+ hovertext=hover_texts, # Adding hover text
139
  marker=dict(
140
  showscale=True,
141
  colorscale='YlGnBu',