File size: 3,679 Bytes
19b102a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import pandas as pd

try:
    # DataFrame styling needs jinja2; importing the Styler module serves as
    # the availability probe, and a failed import disables color highlighting.
    from pandas.io.formats.style import Styler
    HAS_JINJA = True
except ImportError:  # also covers ModuleNotFoundError, which subclasses it
    HAS_JINJA = False


def visualize_approximate_distribution(topic_model,
                                       document: str,
                                       topic_token_distribution: np.ndarray,
                                       normalize: bool = False):
    """ Visualize the topic distribution calculated by `.approximate_distribution`
    on a token level, thereby indicating the extent to which a certain word or phrase
    belongs to a specific topic. The assumption here is that a single word can belong
    to multiple similar topics and as such give information about the broader set of
    topics within a single document.

    NOTE:
    This function will return a stylized pandas dataframe if Jinja2 is installed. If not,
    it will only return a pandas dataframe without color highlighting. To install jinja:

    `pip install jinja2`

    Arguments:
        topic_model: A fitted BERTopic instance.
        document: The document for which you want to visualize
                  the approximated topic distribution.
        topic_token_distribution: The topic-token distribution of the document as
                                  extracted by `.approximate_distribution`, with one
                                  row per token of `document` and one column per topic.
        normalize: Whether to normalize the topic distribution values to lie between
                   0 and 1. The whole token-by-topic matrix is divided by its total,
                   so that all values together sum to 1.

    Returns:
        df: A stylized dataframe indicating the best fitting topics
            for each token. Rows are the topics that have at least one non-zero
            value and columns are the tokens. A plain dataframe is returned when
            Jinja2 is not installed or when no topic has a non-zero value.

    Raises:
        ValueError: If `document` yields no tokens, or if the number of rows in
                    `topic_token_distribution` differs from the number of tokens.

    Examples:

    ```python
    # Calculate the topic distributions on a token level
    # Note that we need to have `calculate_token_level=True`
    topic_distr, topic_token_distr = topic_model.approximate_distribution(
            docs, calculate_token_level=True
    )

    # Visualize the approximated topic distributions
    df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
    df
    ```

    To revert this stylized dataframe back to a regular dataframe,
    you can run the following:

    ```python
    df.data.columns = [column.strip() for column in df.data.columns]
    df = df.data
    ```
    """
    # Tokenize the document with the tokenizer of the model's vectorizer
    analyzer = topic_model.vectorizer_model.build_tokenizer()
    tokens = analyzer(document)

    if len(tokens) == 0:
        raise ValueError("Make sure that your document contains at least 1 token.")

    # Accept array-likes (e.g. nested lists) as well as ndarrays
    distribution = np.asarray(topic_token_distribution)
    n_rows = distribution.shape[0] if distribution.ndim > 0 else 0
    if n_rows != len(tokens):
        raise ValueError(
            f"`topic_token_distribution` must have one row per token: the document "
            f"has {len(tokens)} tokens but the distribution has {n_rows} rows."
        )

    if normalize:
        total = distribution.sum()
        # An all-zero distribution would turn into NaNs (with a RuntimeWarning)
        # when divided by zero; its rows are dropped below either way.
        if total != 0:
            distribution = distribution / total

    # Transpose so that topics become rows and tokens become columns
    df = pd.DataFrame(distribution).T

    # Styler requires unique column labels, so the i-th token gets i trailing
    # spaces appended; repeated tokens stay distinct but look the same.
    df.columns = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]
    # NOTE(review): `_outliers` presumably counts the leading outlier label(s)
    # that have no column in the distribution — confirm against the model.
    df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:]
    # Only keep topics that are assigned to at least one token
    df = df.loc[df.sum(axis=1) != 0, :]

    if len(df) == 0 or not HAS_JINJA:
        return df

    def text_color(val):
        # Zero cells get white text so that they blend into their white background
        color = 'white' if val == 0 else 'black'
        return f"color: {color}"

    def highlight_color(data, color='white'):
        # Element-wise CSS frame: background `color` for zero cells, nothing otherwise
        attr = f"background-color: {color}"
        return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns)

    styler = (
        df.style
          .format("{:.3f}")
          .background_gradient(cmap='Blues', axis=None)
    )
    # `Styler.applymap` is deprecated since pandas 2.1 in favour of `Styler.map`;
    # use the new name when it exists and fall back for older pandas versions.
    map_elementwise = styler.map if hasattr(styler, "map") else styler.applymap
    return map_elementwise(text_color).apply(highlight_color, axis=None)