lynn-twinkl committed on
Commit
d371343
·
1 Parent(s): 5c6f9ad

added: topic count plot

Browse files
Files changed (1) hide show
  1. src/px_charts.py +46 -1
src/px_charts.py CHANGED
@@ -1,12 +1,13 @@
1
  import pandas as pd
2
  import plotly.express as px
3
 
4
- def plot_hist(df: pd.DataFrame, col_to_plot: str, bins: int, height: int = 500):
5
 
6
  plt = px.histogram(
7
  df,
8
  x=col_to_plot,
9
  nbins=bins,
 
10
  color_discrete_sequence=['#646DEF']
11
  )
12
 
@@ -16,3 +17,47 @@ def plot_hist(df: pd.DataFrame, col_to_plot: str, bins: int, height: int = 500):
16
  )
17
 
18
  return plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import plotly.express as px
3
 
4
+ def plot_histogram(df: pd.DataFrame, col_to_plot: str, bins: int, height: int = 500):
5
 
6
  plt = px.histogram(
7
  df,
8
  x=col_to_plot,
9
  nbins=bins,
10
+ title=None,
11
  color_discrete_sequence=['#646DEF']
12
  )
13
 
 
17
  )
18
 
19
  return plt
20
+
21
+
22
+ # =========== TOPIC DISTRIBUTION CHART ===========
23
+
24
+
25
def plot_topic_countplot(
    topics_df: pd.DataFrame,
    topic_id_col: str,
    topic_name_col: str,
    representation_col: str,
    height: int = 500,
    count_col: str = 'Count',
):
    """
    Plot a bar chart of the number of documents per BERTopic topic.

    The first five words of each topic's representation are attached to
    the hover label to give more context about the topic.

    Args:
        topics_df: One row per topic. It is NOT modified; the helper
            column used for the hover text is added to a copy.
        topic_id_col: Column plotted on the x axis.
        topic_name_col: Column shown as "Topic Name" in the hover label.
        representation_col: Column holding each topic's representation.
            List/tuple values are truncated to five items and joined with
            ", "; any other value is shown unchanged.
        height: Figure height passed to the layout.
        count_col: Column plotted on the y axis (defaults to 'Count').

    Returns:
        The figure produced by ``px.bar`` after styling.
    """

    def _top_words(representation):
        # Only sequences are summarised; anything else (e.g. a value that
        # is already a string) is passed through untouched.
        if isinstance(representation, (list, tuple)):
            return ", ".join(representation[:5])
        return representation

    # ``assign`` returns a new frame, so the caller's DataFrame keeps its
    # original columns (the previous in-place column write leaked
    # 'top_5_words' into the input).
    plot_df = topics_df.assign(
        top_5_words=topics_df[representation_col].apply(_top_words)
    )

    fig = px.bar(
        plot_df,
        x=topic_id_col,
        y=count_col,
        custom_data=["top_5_words", topic_name_col],
        title=None,
    )

    # The bar colour is set here; a colour sequence passed to ``px.bar``
    # would be overridden by this call anyway.
    fig.update_traces(
        marker_color='#646DEF',
        textposition='outside',
        hovertemplate=(
            'Topic Name: %{customdata[1]}<br>'
            'Frequency: %{y}<br>'
            'Top 5 words: %{customdata[0]}<extra></extra>'
        ),
    )

    fig.update_layout(
        uniformtext_minsize=10,
        height=height,
        hoverlabel=dict(font_size=14),
    )

    return fig