afanyu237 committed on
Commit
6cc22fb
·
verified ·
1 Parent(s): bb7b43a

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +143 -59
helper.py CHANGED
@@ -5,9 +5,17 @@ from collections import Counter
5
  import emoji
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
- from collections import Counter
9
  import plotly.express as px
 
 
10
 
 
 
 
 
 
 
 
11
 
12
  extract = URLExtract()
13
 
@@ -25,7 +33,7 @@ def fetch_stats(selected_user,df):
25
  words.extend(message.split())
26
 
27
  # fetch number of media messages
28
- num_media_messages = df[df['unfiltered_messages'] == '<media omitted>\n'].shape[0]
29
 
30
  # fetch number of links shared
31
  links = []
@@ -141,42 +149,127 @@ def activity_heatmap(selected_user,df):
141
  user_heatmap = df.pivot_table(index='day', columns='period', values='message', aggfunc='count').fillna(0)
142
 
143
  return user_heatmap
 
144
def generate_wordcloud(text, color):
    """Build a word-cloud image from *text*.

    Args:
        text (str): Raw text to render.
        color (str): Background color passed through to WordCloud.

    Returns:
        WordCloud: The generated word-cloud object (400x300, viridis colormap).
    """
    cloud = WordCloud(
        width=400,
        height=300,
        background_color=color,
        colormap="viridis",
    )
    # WordCloud.generate returns the instance itself, so this matches
    # returning the generated object directly.
    return cloud.generate(text)
147
 
148
- # def plot_topics(topics):
149
- # """
150
- # Plots a bar chart for the top words in each topic.
151
- # """
152
- # if not topics or not isinstance(topics[0], list):
153
- # raise ValueError("topics must be a list of lists of words.")
154
-
155
- # print("Topics received:", topics) # Debugging
156
- # fig, axes = plt.subplots(1, len(topics), figsize=(20, 10))
157
- # if len(topics) == 1:
158
- # axes = [axes] # Ensure axes is iterable for single topic
159
-
160
- # for idx, topic in enumerate(topics):
161
- # if not isinstance(topic, list):
162
- # raise ValueError(f"Topic {idx} is not a list of words.")
163
-
164
- # top_words = topic
165
- # print(f"Top words for Topic {idx}: {top_words}") # Debugging
166
- # axes[idx].barh(top_words, range(len(top_words)))
167
- # axes[idx].set_title(f"Topic {idx}")
168
- # axes[idx].set_xlabel("Word Importance")
169
- # axes[idx].set_ylabel("Top Words")
170
-
171
- # plt.tight_layout()
172
- # return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  def plot_topic_distribution(df):
174
  """
175
  Plots the distribution of topics in the chat data.
176
  """
177
  topic_counts = df['topic'].value_counts().sort_index()
178
  fig, ax = plt.subplots()
179
- sns.barplot(x=topic_counts.index, y=topic_counts.values, ax=ax, palette="viridis")
180
  ax.set_title("Topic Distribution")
181
  ax.set_xlabel("Topic")
182
  ax.set_ylabel("Number of Messages")
@@ -189,6 +282,16 @@ def most_frequent_keywords(messages, top_n=10):
189
  words = [word for msg in messages for word in msg.split()]
190
  word_freq = Counter(words)
191
  return word_freq.most_common(top_n)
 
 
 
 
 
 
 
 
 
 
192
  def plot_topic_distribution_over_time(topic_distribution):
193
  """
194
  Plots the distribution of topics over time using a line chart.
@@ -213,37 +316,11 @@ def plot_most_frequent_keywords(keywords):
213
  """
214
  words, counts = zip(*keywords)
215
  fig, ax = plt.subplots()
216
- sns.barplot(x=list(counts), y=list(words), ax=ax, palette="viridis")
217
  ax.set_title("Most Frequent Keywords")
218
  ax.set_xlabel("Frequency")
219
  ax.set_ylabel("Keyword")
220
  return fig
221
- def topic_distribution_over_time(df, time_freq='M'):
222
- """
223
- Analyzes the distribution of topics over time.
224
- """
225
- # Group by time interval and topic
226
- df['time_period'] = df['date'].dt.to_period(time_freq)
227
- topic_distribution = df.groupby(['time_period', 'topic']).size().unstack(fill_value=0)
228
- return topic_distribution
229
-
230
- def plot_topic_distribution_over_time(topic_distribution):
231
- """
232
- Plots the distribution of topics over time using a line chart.
233
- """
234
- fig, ax = plt.subplots(figsize=(12, 6))
235
-
236
- # Plot each topic as a separate line
237
- for topic in topic_distribution.columns:
238
- ax.plot(topic_distribution.index.to_timestamp(), topic_distribution[topic], label=f"Topic {topic}")
239
-
240
- ax.set_title("Topic Distribution Over Time")
241
- ax.set_xlabel("Time Period")
242
- ax.set_ylabel("Number of Messages")
243
- ax.legend(title="Topics", bbox_to_anchor=(1.05, 1), loc='upper left')
244
- plt.xticks(rotation=45)
245
- plt.tight_layout()
246
- return fig
247
 
248
  def plot_topic_distribution_over_time_plotly(topic_distribution):
249
  """
@@ -257,6 +334,7 @@ def plot_topic_distribution_over_time_plotly(topic_distribution):
257
  title="Topic Distribution Over Time", labels={'time_period': 'Time Period', 'count': 'Number of Messages'})
258
  fig.update_layout(legend_title_text='Topics', xaxis_tickangle=-45)
259
  return fig
 
260
  def plot_clusters(reduced_features, clusters):
261
  """
262
  Visualize clusters using t-SNE.
@@ -279,19 +357,25 @@ def plot_clusters(reduced_features, clusters):
279
  plt.ylabel("t-SNE Component 2")
280
  plt.tight_layout()
281
  return plt.gcf()
 
 
 
 
 
282
  def get_cluster_labels(df, n_clusters):
283
  """
284
  Generate descriptive labels for each cluster based on top keywords.
285
  """
286
- from sklearn.feature_extraction.text import TfidfVectorizer
287
- import numpy as np
288
-
289
  vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')
290
  tfidf_matrix = vectorizer.fit_transform(df['lemmatized_message'])
291
 
292
  cluster_labels = {}
 
 
 
293
  for cluster_id in range(n_clusters):
294
- cluster_indices = df[df['cluster'] == cluster_id].index
 
295
  if len(cluster_indices) > 0:
296
  cluster_tfidf = tfidf_matrix[cluster_indices]
297
  top_keywords = np.argsort(cluster_tfidf.sum(axis=0).A1)[-3:][::-1]
 
5
  import emoji
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
 
8
  import plotly.express as px
9
+ import numpy as np
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
 
12
+ # Import the AI topic titles module
13
+ try:
14
+ from ai_topic_titles import generate_topic_titles
15
+ AI_TOPIC_TITLES_AVAILABLE = True
16
+ except ImportError:
17
+ AI_TOPIC_TITLES_AVAILABLE = False
18
+ print("Note: ai_topic_titles module not found. Using basic topic titles.")
19
 
20
  extract = URLExtract()
21
 
 
33
  words.extend(message.split())
34
 
35
  # fetch number of media messages
36
+ num_media_messages = df[df['unfiltered_messages'].str.contains('<media omitted>', case=False, na=False)].shape[0]
37
 
38
  # fetch number of links shared
39
  links = []
 
149
  user_heatmap = df.pivot_table(index='day', columns='period', values='message', aggfunc='count').fillna(0)
150
 
151
  return user_heatmap
152
+
153
  def generate_wordcloud(text, color):
154
  wordcloud = WordCloud(width=400, height=300, background_color=color, colormap="viridis").generate(text)
155
  return wordcloud
156
 
157
def create_heuristic_title(topic, idx):
    """Generate a simple title from a topic's three leading words.

    Args:
        topic (list[str]): Words for one topic, most relevant first.
        idx (int): Zero-based topic index (rendered one-based).

    Returns:
        str: e.g. "Topic 1: word_a, word_b, word_c".
    """
    leading_words = ", ".join(topic[:3])
    return f"Topic {idx + 1}: {leading_words}"
162
+
163
def generate_topic_titles_wrapper(topics, hf_token=None, use_ai=True, **kwargs):
    """Generate titles for topics using AI or a basic heuristic.

    Tries the optional ``ai_topic_titles`` backend first (when ``use_ai``
    is true and the module imported successfully); any failure there falls
    back to a simple "Topic N: first three words" heuristic.

    Args:
        topics (list): List of topics, where each topic is a list of words.
        hf_token (str, optional): Hugging Face token; selects the remote
            API when provided, a local model otherwise.
        use_ai (bool): Whether to attempt the AI backend (default: True
            if available).
        **kwargs: Additional parameters forwarded to the AI function.

    Returns:
        list: List of topic title strings, one per topic.
    """
    if use_ai and AI_TOPIC_TITLES_AVAILABLE:
        try:
            # Remote HF inference needs a token; otherwise run locally.
            api_type = "huggingface" if hf_token else "local"

            # NOTE(review): this module later defines its own
            # generate_topic_titles, which shadows the one imported from
            # ai_topic_titles — confirm this call does not recurse when
            # the AI module is present.
            titles = generate_topic_titles(
                topics=topics,
                api_type=api_type,
                hf_token=hf_token,
                **kwargs
            )

            print("AI-generated topic titles:")
            for title in titles:
                print(f"  {title}")
            return titles

        except Exception as e:
            # Best-effort: any backend failure falls through to the heuristic.
            print(f"AI topic title generation failed: {e}")
            print("Falling back to basic titles...")

    # Basic method (fallback): first three words of each topic.
    titles = []
    for idx, topic in enumerate(topics):
        if isinstance(topic, list) and len(topic) >= 3:
            # Create title from first 3 words
            title = f"Topic {idx + 1}: {', '.join(topic[:3])}"
        else:
            title = f"Topic {idx + 1}: General Discussion"
        titles.append(title)
    # Bug fix: the debug print previously sat after the loop but referenced
    # the loop variable, so only the LAST title was shown; print them all.
    print("THESE ARE THE TOPICS TITLES: ", titles)
    return titles
210
+
211
+ # Keep the old function for backward compatibility
212
def generate_topic_titles(topics, hf_token=None, **kwargs):
    """Generate titles for topics based on their top words (compat shim).

    Kept for backward compatibility; delegates to
    generate_topic_titles_wrapper(), attempting the AI backend by default.

    NOTE(review): this definition shadows the ``generate_topic_titles``
    imported from ``ai_topic_titles``, which the wrapper calls on its AI
    path — verify this does not recurse when that module is available.

    Args:
        topics (list): List of topics, where each topic is a list of words.
        hf_token (str, optional): Hugging Face token for AI mode.
        **kwargs: Additional parameters for the AI function; ``use_ai``
            (default True) is extracted and honored.

    Returns:
        list: List of topic titles.
    """
    should_use_ai = kwargs.pop('use_ai', True)
    return generate_topic_titles_wrapper(topics, hf_token, should_use_ai, **kwargs)
227
+
228
def plot_topics(topics, use_ai=True, hf_token=None, **kwargs):
    """Plot a horizontal bar chart for the top words in each topic.

    Args:
        topics: List of topics, each a list of words (most relevant first).
        use_ai: Whether to use AI for titles.
        hf_token: Hugging Face token for AI.
        **kwargs: Additional parameters forwarded to the title generator.

    Returns:
        matplotlib.figure.Figure: The plot figure.

    Raises:
        ValueError: If topics is empty or not a list of lists of words.
    """
    if not topics or not isinstance(topics[0], list):
        raise ValueError("topics must be a list of lists of words.")

    # One title per subplot, AI-generated when possible.
    titles = generate_topic_titles_wrapper(topics, hf_token=hf_token, use_ai=use_ai, **kwargs)

    fig, axes = plt.subplots(1, len(topics), figsize=(20, 10))
    if len(topics) == 1:
        axes = [axes]  # Ensure axes is iterable for single topic

    for idx, topic in enumerate(topics):
        if not isinstance(topic, list):
            raise ValueError(f"Topic {idx} is not a list of words.")

        top_words = topic[:10]  # Show top 10 words
        n = len(top_words)
        # Bug fix: widths were range(n), which gave the first (most
        # relevant) word a zero-width, invisible bar and made importance
        # read backwards. Use descending widths so earlier words get
        # longer bars.
        widths = range(n, 0, -1)
        axes[idx].barh(range(n), widths)
        axes[idx].set_yticks(range(n))
        axes[idx].set_yticklabels(top_words)
        axes[idx].invert_yaxis()  # most relevant word at the top
        axes[idx].set_title(titles[idx], fontsize=14, fontweight='bold')
        axes[idx].set_xlabel("Word Importance")
        axes[idx].set_ylabel("Top Words")

    plt.tight_layout()
    return fig
265
+
266
  def plot_topic_distribution(df):
267
  """
268
  Plots the distribution of topics in the chat data.
269
  """
270
  topic_counts = df['topic'].value_counts().sort_index()
271
  fig, ax = plt.subplots()
272
+ sns.barplot(x=topic_counts.index, y=topic_counts.values, ax=ax, palette="viridis", hue=topic_counts.index, legend=False)
273
  ax.set_title("Topic Distribution")
274
  ax.set_xlabel("Topic")
275
  ax.set_ylabel("Number of Messages")
 
282
  words = [word for msg in messages for word in msg.split()]
283
  word_freq = Counter(words)
284
  return word_freq.most_common(top_n)
285
+
286
def topic_distribution_over_time(df, time_freq='M'):
    """Count messages per topic within each calendar period.

    Side effect: adds a 'time_period' Period column to *df* (kept for
    callers that rely on it).

    Args:
        df: DataFrame with 'date' (datetime64) and 'topic' columns.
        time_freq (str): Pandas period alias, e.g. 'M' for monthly.

    Returns:
        DataFrame: Rows are time periods, columns are topics, values are
        message counts; missing period/topic combinations are 0.
    """
    df['time_period'] = df['date'].dt.to_period(time_freq)
    counts = df.groupby(['time_period', 'topic']).size()
    return counts.unstack(fill_value=0)
294
+
295
  def plot_topic_distribution_over_time(topic_distribution):
296
  """
297
  Plots the distribution of topics over time using a line chart.
 
316
  """
317
  words, counts = zip(*keywords)
318
  fig, ax = plt.subplots()
319
+ sns.barplot(x=list(counts), y=list(words), ax=ax, palette="viridis", hue=list(words), legend=False)
320
  ax.set_title("Most Frequent Keywords")
321
  ax.set_xlabel("Frequency")
322
  ax.set_ylabel("Keyword")
323
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  def plot_topic_distribution_over_time_plotly(topic_distribution):
326
  """
 
334
  title="Topic Distribution Over Time", labels={'time_period': 'Time Period', 'count': 'Number of Messages'})
335
  fig.update_layout(legend_title_text='Topics', xaxis_tickangle=-45)
336
  return fig
337
+
338
  def plot_clusters(reduced_features, clusters):
339
  """
340
  Visualize clusters using t-SNE.
 
357
  plt.ylabel("t-SNE Component 2")
358
  plt.tight_layout()
359
  return plt.gcf()
360
+
361
def remove_emojis(text):
    """Removes emojis from text to prevent matplotlib warnings.

    Bug fix: the previous ASCII encode/decode round-trip deleted *every*
    non-ASCII character (accents, CJK, etc.), not just emojis. Only
    codepoints in common emoji ranges are removed now, so non-English
    text survives.

    Args:
        text (str): Input string.

    Returns:
        str: The string with emoji codepoints removed.
    """
    import re  # local import keeps this helper self-contained

    # Common emoji blocks: pictographs/emoticons, misc symbols & dingbats,
    # arrows, variation selectors, zero-width joiner.
    # NOTE(review): not exhaustive — extend the ranges if stray glyphs
    # still reach matplotlib.
    emoji_pattern = re.compile(
        "["
        "\U0001F000-\U0001FAFF"  # pictographs, emoticons, symbols ext.
        "\u2600-\u27BF"          # misc symbols & dingbats
        "\u2B00-\u2BFF"          # arrows & additional symbols
        "\uFE0E\uFE0F"           # variation selectors
        "\u200D"                 # zero-width joiner
        "]"
    )
    return emoji_pattern.sub("", text)
364
+
365
  def get_cluster_labels(df, n_clusters):
366
  """
367
  Generate descriptive labels for each cluster based on top keywords.
368
  """
 
 
 
369
  vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')
370
  tfidf_matrix = vectorizer.fit_transform(df['lemmatized_message'])
371
 
372
  cluster_labels = {}
373
+ # Reset index to ensure alignment with tfidf_matrix
374
+ df_reset = df.reset_index(drop=True)
375
+
376
  for cluster_id in range(n_clusters):
377
+ # Get indices where cluster matches
378
+ cluster_indices = df_reset[df_reset['cluster'] == cluster_id].index
379
  if len(cluster_indices) > 0:
380
  cluster_tfidf = tfidf_matrix[cluster_indices]
381
  top_keywords = np.argsort(cluster_tfidf.sum(axis=0).A1)[-3:][::-1]