|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
# Jinja2 is an optional dependency: pandas' Styler (used for the colored
# output) can only be imported when Jinja2 is installed. Record availability
# so callers can fall back to a plain, un-styled dataframe.
try:
    from pandas.io.formats.style import Styler  # noqa: F401
except (ModuleNotFoundError, ImportError):
    HAS_JINJA = False
else:
    HAS_JINJA = True
|
|
|
|
|
|
|
|
def visualize_approximate_distribution(topic_model,
                                       document: str,
                                       topic_token_distribution: np.ndarray,
                                       normalize: bool = False):
    """ Visualize the topic distribution calculated by `.approximate_distribution`
    on a token level. Thereby indicating the extent to which a certain word or phrase belongs
    to a specific topic. The assumption here is that a single word can belong to multiple
    similar topics and as such give information about the broader set of topics within
    a single document.

    NOTE:
        This function will return a stylized pandas dataframe if Jinja2 is installed. If not,
        it will only return a pandas dataframe without color highlighting. To install jinja:

        `pip install jinja2`

    Arguments:
        topic_model: A fitted BERTopic instance.
        document: The document for which you want to visualize
                  the approximated topic distribution.
        topic_token_distribution: The topic-token distribution of the document as
                                  extracted by `.approximate_distribution`
        normalize: Whether to normalize, between 0 and 1 (summing to 1), the
                   topic distribution values.

    Returns:
        df: A stylized dataframe indicating the best fitting topics
            for each token.

    Raises:
        ValueError: If the tokenizer extracts no tokens from `document`.

    Examples:

    ```python
    # Calculate the topic distributions on a token level
    # Note that we need to have `calculate_token_level=True`
    topic_distr, topic_token_distr = topic_model.approximate_distribution(
            docs, calculate_token_level=True
    )

    # Visualize the approximated topic distributions
    df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
    df
    ```

    To revert this stylized dataframe back to a regular dataframe,
    you can run the following:

    ```python
    df.data.columns = [column.strip() for column in df.data.columns]
    df = df.data
    ```
    """
    # Tokenize with the exact tokenizer the topic model was built with so the
    # columns line up one-to-one with the token-level distribution.
    analyzer = topic_model.vectorizer_model.build_tokenizer()
    tokens = analyzer(document)

    if len(tokens) == 0:
        raise ValueError("Make sure that your document contains at least 1 token.")

    # Rows become topics, columns become the tokens in document order.
    # `topic_token_distribution` is assumed to be (n_tokens, n_topics) —
    # matching `.approximate_distribution(..., calculate_token_level=True)`.
    if normalize:
        df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T
    else:
        df = pd.DataFrame(topic_token_distribution).T

    # Pad each token with a varying number of trailing spaces so duplicate
    # tokens still produce unique column names while displaying identically.
    # (A previous redundant assignment using `token_i` suffixes was removed —
    # it was immediately overwritten by this one.)
    df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
    df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:]
    # Drop topics that carry no weight anywhere in this document.
    df = df.loc[(df.sum(axis=1) != 0), :]

    def text_color(val):
        # Render zero cells in white so they visually disappear against the
        # plain background; non-zero cells stay readable in black.
        color = 'white' if val == 0 else 'black'
        return 'color: %s' % color

    def highlight_color(data, color='white'):
        # Give zero-valued cells a plain background so the blue gradient only
        # colors cells that actually carry topic weight.
        attr = 'background-color: {}'.format(color)
        return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns)

    if len(df) == 0 or not HAS_JINJA:
        # Either there is nothing to style or Jinja2 is unavailable: return
        # the plain dataframe. (The no-Jinja path previously fell through and
        # implicitly returned None, contradicting the documented contract.)
        return df

    styler = (
        df.style
          .format("{:.3f}")
          .background_gradient(cmap='Blues', axis=None)
    )
    # `Styler.applymap` was renamed to `Styler.map` in pandas 2.1; prefer the
    # new name when present to avoid the deprecation warning.
    cell_map = getattr(styler, "map", None) or styler.applymap
    df = cell_map(lambda x: text_color(x)).apply(highlight_color, axis=None)
    return df
|
|
|