File size: 2,228 Bytes
cbff612
2d4c953
 
 
d66308d
cbff612
d66308d
cbff612
 
 
 
 
d66308d
 
2d4c953
 
 
 
 
d66308d
fb8ff02
2d4c953
 
 
 
d66308d
 
cbff612
 
 
 
2d4c953
 
 
d371343
 
 
 
 
d66308d
d371343
 
 
 
 
 
 
 
 
 
 
 
 
 
d66308d
d371343
 
fb8ff02
 
d371343
fb8ff02
d371343
 
cbff612
 
 
d371343
 
 
 
 
 
fb8ff02
 
d66308d
 
cbff612
 
 
d371343
 
 
fb8ff02
d371343
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# ========= CONFIGURATION ==========
import pandas as pd
import plotly.express as px


# Shared layout styling, read as globals by the chart functions in this module.
title_font_size = 20
title_font_color = "#808393"

# Both axis titles use the same font size.
xaxis_title_font_size = 16
yaxis_title_font_size = 16


# ======== FUNCTIONS ========

def plot_histogram(df: pd.DataFrame, col_to_plot: str, bins: int, height: int = 500, title:str = None):
    """
    Build a Plotly Express histogram of a single DataFrame column.

    `col_to_plot` is used as the x axis and `bins` is forwarded as `nbins`.
    The figure is styled with the module-level title/axis font settings
    and returned to the caller.
    """
    figure = px.histogram(
        df,
        x=col_to_plot,
        nbins=bins,
        title=title,
        color_discrete_sequence=['#646DEF'],
    )

    # Layout options are collected first so the styling reads as one unit.
    layout_options = {
        "bargap": 0.1,
        "height": height,
        "title_font_size": title_font_size,
        "title_font_color": title_font_color,
        "xaxis_title_font_size": xaxis_title_font_size,
        "yaxis_title_font_size": yaxis_title_font_size,
    }
    figure.update_layout(**layout_options)

    return figure


# =========== TOPIC DISTRIBUTION CHART  ===========


def plot_topic_countplot(topics_df: pd.DataFrame, topic_id_col: str, topic_name_col: str, representation_col: str, height: int = 500, title: str = None, count_col: str = 'Count'):
    """
    Plot a bar chart of topic frequencies for BERTopic topics.

    The first 5 words of each topic's representation are shown in the
    hover tooltip, together with the topic name, to provide more context.
    The input DataFrame is not modified.

    Parameters
    ----------
    topics_df : pd.DataFrame
        Table holding the columns named by the arguments below.
    topic_id_col : str
        Column used for the x axis (rendered as a categorical axis).
    topic_name_col : str
        Column shown as "Topic Name" in the hover tooltip.
    representation_col : str
        Column holding each topic's representation. List values are
        truncated to their first 5 items and joined with ", "; any other
        value is shown unchanged.
    height : int
        Figure height passed to the layout.
    title : str, optional
        Figure title.
    count_col : str
        Column used for the bar heights (y axis). Defaults to 'Count'.

    Returns
    -------
    The figure created by ``px.bar``, with traces and layout updated.
    """

    def _top_words(representation):
        # Only list representations are summarised; other values pass through.
        if isinstance(representation, list):
            return ", ".join(representation[:5])
        return representation

    ## ----- Extract top 5 words ----
    # `assign` returns a new DataFrame, so the caller's frame does not gain
    # a 'top_5_words' column as a side effect of plotting.
    plot_df = topics_df.assign(
        top_5_words=topics_df[representation_col].apply(_top_words)
    )

    fig = px.bar(
        plot_df,
        x=topic_id_col,
        y=count_col,
        custom_data=["top_5_words", topic_name_col],
        title=title,
    )

    # Treat topic ids as labels rather than as a numeric axis.
    fig.update_xaxes(type='category')

    fig.update_traces(
        marker_color='#EF64B3',
        textposition='outside',
        # customdata indices follow the order given to `custom_data` above.
        hovertemplate=(
            '<b>Topic Name</b>: %{customdata[1]}<br>'
            '<b>Frequency:</b> %{y}<br>'
            '<b>Top 5 words:</b> %{customdata[0]}<extra></extra>'
        )
    )

    fig.update_layout(
        height=height,
        hoverlabel=dict(
            font_size=13,
            align="left"
        ),
        title_font_size=title_font_size,
        title_font_color=title_font_color,
        xaxis_title_font_size=xaxis_title_font_size,
        yaxis_title_font_size=yaxis_title_font_size,
    )

    return fig