import numpy as np import pandas as pd try: from pandas.io.formats.style import Styler HAS_JINJA = True except (ModuleNotFoundError, ImportError): HAS_JINJA = False def visualize_approximate_distribution(topic_model, document: str, topic_token_distribution: np.ndarray, normalize: bool = False): """ Visualize the topic distribution calculated by `.approximate_topic_distribution` on a token level. Thereby indicating the extend to which a certain word or phrases belong to a specific topic. The assumption here is that a single word can belong to multiple similar topics and as such give information about the broader set of topics within a single document. NOTE: This fuction will return a stylized pandas dataframe if Jinja2 is installed. If not, it will only return a pandas dataframe without color highlighting. To install jinja: `pip install jinja2` Arguments: topic_model: A fitted BERTopic instance. document: The document for which you want to visualize the approximated topic distribution. topic_token_distribution: The topic-token distribution of the document as extracted by `.approximate_topic_distribution` normalize: Whether to normalize, between 0 and 1 (summing to 1), the topic distribution values. Returns: df: A stylized dataframe indicating the best fitting topics for each token. Examples: ```python # Calculate the topic distributions on a token level # Note that we need to have `calculate_token_level=True` topic_distr, topic_token_distr = topic_model.approximate_distribution( docs, calculate_token_level=True ) # Visualize the approximated topic distributions df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0]) df ``` To revert this stylized dataframe back to a regular dataframe, you can run the following: ```python df.data.columns = [column.strip() for column in df.data.columns] df = df.data ``` """ # Tokenize document analyzer = topic_model.vectorizer_model.build_tokenizer() tokens = analyzer(document) if len(tokens) == 0: raise ValueError("Make sure that your document contains at least 1 token.") # Prepare dataframe with results if normalize: df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T else: df = pd.DataFrame(topic_token_distribution).T df.columns = [f"{token}_{i}" for i, token in enumerate(tokens)] df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)] df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:] df = df.loc[(df.sum(axis=1) != 0), :] # Style the resulting dataframe def text_color(val): color = 'white' if val == 0 else 'black' return 'color: %s' % color def highligh_color(data, color='white'): attr = 'background-color: {}'.format(color) return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns) if len(df) == 0: return df elif HAS_JINJA: df = ( df.style .format("{:.3f}") .background_gradient(cmap='Blues', axis=None) .applymap(lambda x: text_color(x)) .apply(highligh_color, axis=None) ) return df