TopicModelingRepo / BERTopic /bertopic /plotting /_approximate_distribution.py
kisejin's picture
Upload 261 files
19b102a verified
import numpy as np
import pandas as pd
try:
from pandas.io.formats.style import Styler
HAS_JINJA = True
except (ModuleNotFoundError, ImportError):
HAS_JINJA = False
def visualize_approximate_distribution(topic_model,
                                       document: str,
                                       topic_token_distribution: np.ndarray,
                                       normalize: bool = False):
    """ Visualize the topic distribution calculated by `.approximate_topic_distribution`
    on a token level. Thereby indicating the extent to which a certain word or phrase belongs
    to a specific topic. The assumption here is that a single word can belong to multiple
    similar topics and as such give information about the broader set of topics within
    a single document.

    NOTE:
        This function will return a stylized pandas dataframe if Jinja2 is installed. If not,
        it will only return a pandas dataframe without color highlighting. To install jinja:

        `pip install jinja2`

    Arguments:
        topic_model: A fitted BERTopic instance.
        document: The document for which you want to visualize
                  the approximated topic distribution.
        topic_token_distribution: The topic-token distribution of the document as
                                  extracted by `.approximate_topic_distribution`.
                                  It must contain one row per token of `document`.
        normalize: Whether to normalize, between 0 and 1 (summing to 1), the
                   topic distribution values. A distribution that sums to 0
                   is left untouched since it cannot be normalized.

    Returns:
        df: A stylized dataframe indicating the best fitting topics
            for each token. A regular (non-stylized) dataframe is returned
            instead when Jinja2 is not installed or when no topic has a
            non-zero value for the document.

    Raises:
        ValueError: If the document contains no tokens or if the number of
                    tokens does not match the rows of `topic_token_distribution`.

    Examples:

    ```python
    # Calculate the topic distributions on a token level
    # Note that we need to have `calculate_token_level=True`
    topic_distr, topic_token_distr = topic_model.approximate_distribution(
            docs, calculate_token_level=True
    )

    # Visualize the approximated topic distributions
    df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
    df
    ```

    To revert this stylized dataframe back to a regular dataframe,
    you can run the following:

    ```python
    df.data.columns = [column.strip() for column in df.data.columns]
    df = df.data
    ```
    """
    # Tokenize document
    analyzer = topic_model.vectorizer_model.build_tokenizer()
    tokens = analyzer(document)

    if len(tokens) == 0:
        raise ValueError("Make sure that your document contains at least 1 token.")

    # Prepare dataframe with results
    values = topic_token_distribution
    if normalize:
        total = topic_token_distribution.sum()
        # Dividing an all-zero distribution would only produce NaNs (and a
        # NumPy warning); such rows are filtered out below anyway.
        if total != 0:
            values = topic_token_distribution / total
    df = pd.DataFrame(values).T

    if df.shape[1] != len(tokens):
        raise ValueError(
            f"The document contains {len(tokens)} tokens but `topic_token_distribution` "
            f"describes {df.shape[1]}. Make sure both belong to the same document."
        )

    # Repeated tokens would yield duplicate column names, which the Styler
    # cannot handle; padding with `i` spaces keeps the labels visually
    # identical while making every column name unique.
    df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
    df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:]

    # Only keep topics that contribute to at least one token
    df = df.loc[(df.sum(axis=1) != 0), :]

    # Style the resulting dataframe
    def text_color(val):
        # Zero cells get white text so they disappear into the white background
        color = 'white' if val == 0 else 'black'
        return 'color: %s' % color

    def highlight_color(data, color='white'):
        # Override the gradient for zero cells with a plain background
        attr = 'background-color: {}'.format(color)
        return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns)

    if len(df) == 0 or not HAS_JINJA:
        return df

    styler = df.style.format("{:.3f}").background_gradient(cmap='Blues', axis=None)
    # `Styler.applymap` was deprecated in pandas 2.1 in favor of `Styler.map`;
    # fall back to the old name on pandas versions that predate `map`.
    elementwise = getattr(styler, "map", None) or styler.applymap
    return elementwise(text_color).apply(highlight_color, axis=None)