File size: 2,228 Bytes
cbff612 2d4c953 d66308d cbff612 d66308d cbff612 d66308d 2d4c953 d66308d fb8ff02 2d4c953 d66308d cbff612 2d4c953 d371343 d66308d d371343 d66308d d371343 fb8ff02 d371343 fb8ff02 d371343 cbff612 d371343 fb8ff02 d66308d cbff612 d371343 fb8ff02 d371343 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
# ========= CONFIGURATION ==========
from typing import Optional

import pandas as pd
import plotly.express as px
# Typography shared by every chart built in this module: title size/colour
# and the font size used for both axis titles.
title_font_size = 20
title_font_color = '#808393'
xaxis_title_font_size = yaxis_title_font_size = 16
# ======== FUNCTIONS ========
def plot_histogram(
    df: pd.DataFrame,
    col_to_plot: str,
    bins: int,
    height: int = 500,
    title: Optional[str] = None,
    bar_color: str = '#646DEF',
):
    """Plot a single-colour histogram of one DataFrame column.

    Args:
        df: Source data.
        col_to_plot: Column whose values are binned along the x axis.
        bins: Number of bins requested from plotly (forwarded as ``nbins``).
        height: Figure height in pixels.
        title: Optional chart title; ``None`` leaves the chart untitled.
        bar_color: Fill colour of the bars. The default matches the colour
            this function has always used.

    Returns:
        The plotly figure, styled with the module-level title and axis-title
        font settings.
    """
    fig = px.histogram(
        df,
        x=col_to_plot,
        nbins=bins,
        title=title,
        color_discrete_sequence=[bar_color],
    )
    fig.update_layout(
        bargap=0.1,  # small gap so adjacent bars stay visually distinct
        height=height,
        title_font_size=title_font_size,
        title_font_color=title_font_color,
        xaxis_title_font_size=xaxis_title_font_size,
        yaxis_title_font_size=yaxis_title_font_size,
    )
    return fig
# =========== TOPIC DISTRIBUTION CHART ===========
def plot_topic_countplot(
    topics_df: pd.DataFrame,
    topic_id_col: str,
    topic_name_col: str,
    representation_col: str,
    height: int = 500,
    title: Optional[str] = None,
    count_col: str = 'Count',
    bar_color: str = '#EF64B3',
):
    """Plot a bar chart of per-topic counts for BERTopic topics.

    The hover tooltip of each bar shows the topic name, its count and the
    first five words of the topic's representation, to give more context.

    Args:
        topics_df: One row per topic. Must contain ``topic_id_col``,
            ``topic_name_col``, ``representation_col`` and ``count_col``.
            The DataFrame is not modified.
        topic_id_col: Column plotted on the x axis (forced to categorical).
        topic_name_col: Column shown as "Topic Name" in the tooltip.
        representation_col: Column holding each topic's representation.
            List/tuple values are truncated to their first five items and
            joined with ", "; any other value is shown unchanged.
        height: Figure height in pixels.
        title: Optional chart title; ``None`` leaves the chart untitled.
        count_col: Column plotted on the y axis (default ``'Count'``).
        bar_color: Fill colour of the bars.

    Returns:
        The plotly figure.
    """

    def _top_words(representation):
        # Only sequences are truncated; strings, NaN, etc. pass through as-is.
        if isinstance(representation, (list, tuple)):
            return ", ".join(str(word) for word in representation[:5])
        return representation

    ## ----- Extract top 5 words ----
    # Build the hover column on a copy so the caller's DataFrame is not mutated.
    plot_df = topics_df.assign(
        top_5_words=topics_df[representation_col].apply(_top_words)
    )

    fig = px.bar(
        plot_df,
        x=topic_id_col,
        y=count_col,
        custom_data=["top_5_words", topic_name_col],
        title=title,
    )
    # Topic ids are labels, not magnitudes: keep one evenly spaced bar per id.
    fig.update_xaxes(type='category')
    fig.update_traces(
        marker_color=bar_color,
        # NOTE(review): no `text` is set on the trace, so this appears to have
        # no visible effect; kept to preserve the original trace settings.
        textposition='outside',
        # customdata[0] -> top_5_words, customdata[1] -> topic name (see custom_data).
        hovertemplate=(
            '<b>Topic Name</b>: %{customdata[1]}<br>'
            '<b>Frequency:</b> %{y}<br>'
            '<b>Top 5 words:</b> %{customdata[0]}<extra></extra>'
        ),
    )
    fig.update_layout(
        height=height,
        hoverlabel=dict(
            font_size=13,
            align="left",
        ),
        title_font_size=title_font_size,
        title_font_color=title_font_color,
        xaxis_title_font_size=xaxis_title_font_size,
        yaxis_title_font_size=yaxis_title_font_size,
    )
    return fig
|