import numpy as np
import pandas as pd

from typing import Callable, List, Union
from scipy.sparse import csr_matrix
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity

import plotly.graph_objects as go
import plotly.figure_factory as ff

from bertopic._utils import validate_distance_matrix


def visualize_hierarchy(topic_model,
                        orientation: str = "left",
                        topics: List[int] = None,
                        top_n_topics: int = None,
                        custom_labels: Union[bool, str] = False,
                        title: str = "Hierarchical Clustering",
                        width: int = 1000,
                        height: int = 600,
                        hierarchical_topics: pd.DataFrame = None,
                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
                        color_threshold: int = 1) -> go.Figure:
    """ Visualize a hierarchical structure of the topics

    By default, a ward linkage function is used to perform the
    hierarchical clustering based on the cosine distance matrix
    between the topic representations: the c-TF-IDF matrix when
    `topic_model.c_tf_idf_` is available, otherwise the topic embeddings.

    Arguments:
        topic_model: A fitted BERTopic instance.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        topics: A selection of topics to visualize
        top_n_topics: Only select the top n most frequent topics
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.
        title: Title of the plot.
        width: The width of the figure. Only works if orientation is set to 'left'
        height: The height of the figure. Only works if orientation is set to 'bottom'
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                            You can pass any function that returns either a square matrix of
                            shape (n_samples, n_samples) with zeros on the diagonal and
                            non-negative values or condensed distance matrix of shape
                            (n_samples * (n_samples - 1) / 2,) containing the upper
                            triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        color_threshold: Value at which the separation of clusters will be made which
                         will result in different colors for different clusters.
                         A higher value will typically lead to fewer colored clusters.

    Returns:
        fig: A plotly figure

    Examples:

    To visualize the hierarchical structure of
    topics simply run:

    ```python
    topic_model.visualize_hierarchy()
    ```

    If you also want the labels of hierarchical topics visualized,
    run the following:

    ```python
    # Extract hierarchical topics and their representations
    hierarchical_topics = topic_model.hierarchical_topics(docs)

    # Visualize these representations
    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
    ```

    If you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_hierarchy()
    fig.write_html("path/to/file.html")
    ```
    """
    if distance_function is None:
        distance_function = lambda x: 1 - cosine_similarity(x)

    if linkage_function is None:
        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)

    # Select topics based on top_n and topics args (the outlier topic -1 is never plotted
    # unless it is passed explicitly through `topics`)
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Map each selected topic to its row in the (sorted) topic representation matrices
    all_topics = sorted(list(topic_model.get_topics().keys()))
    indices = np.array([all_topics.index(topic) for topic in topics])

    # Select topic representations: prefer c-TF-IDF, fall back to topic embeddings
    if topic_model.c_tf_idf_ is not None:
        embeddings = topic_model.c_tf_idf_[indices]
    else:
        embeddings = np.array(topic_model.topic_embeddings_)[indices]

    # Hierarchical names are only computed when as many topics are plotted as there are
    # non-outlier topics (presumably: neither `topics` nor `top_n_topics` restricted the
    # selection), so that dendrogram leaves line up with the hierarchy.
    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
        annotations = _get_annotations(topic_model=topic_model,
                                       hierarchical_topics=hierarchical_topics,
                                       embeddings=embeddings,
                                       distance_function=distance_function,
                                       linkage_function=linkage_function,
                                       orientation=orientation,
                                       custom_labels=custom_labels)
    else:
        annotations = None

    # wrap distance function to validate input and return a condensed distance matrix
    distance_function_viz = lambda x: validate_distance_matrix(
        distance_function(x), embeddings.shape[0])

    # Create dendrogram
    fig = ff.create_dendrogram(embeddings,
                               orientation=orientation,
                               distfun=distance_function_viz,
                               linkagefun=linkage_function,
                               hovertext=annotations,
                               color_threshold=color_threshold)

    # Create nicer labels. The dendrogram's tick texts are row positions into `embeddings`
    # (i.e. positions into `topics`), not topic IDs, so translate them first.
    tick_axis = "yaxis" if orientation == "left" else "xaxis"
    tick_topics = [topics[int(x)] for x in fig.layout[tick_axis]["ticktext"]]
    if isinstance(custom_labels, str):
        new_labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic]
                      for topic in tick_topics]
        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
    elif topic_model.custom_labels_ is not None and custom_labels:
        # custom_labels_ is offset by one when the model has an outlier topic
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers]
                      for topic in tick_topics]
    else:
        new_labels = [[[str(topic), None]] + topic_model.get_topic(topic)
                      for topic in tick_topics]
        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]

    # Stylize layout
    fig.update_layout(
        plot_bgcolor='#ECEFF1',
        template="plotly_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    # Stylize orientation
    if orientation == "left":
        fig.update_layout(height=200 + (15 * len(topics)),
                          width=width,
                          yaxis=dict(tickmode="array",
                                     ticktext=new_labels))

        # Fix empty space on the bottom of the graph
        y_max = max([trace['y'].max() + 5 for trace in fig['data']])
        y_min = min([trace['y'].min() - 5 for trace in fig['data']])
        fig.update_layout(yaxis=dict(range=[y_min, y_max]))

    else:
        fig.update_layout(width=200 + (15 * len(topics)),
                          height=height,
                          xaxis=dict(tickmode="array",
                                     ticktext=new_labels))

    if hierarchical_topics is not None:
        # Each dendrogram link carries hover text [left_name, "", "", right_name] (see
        # `_get_annotations`), so points 0 and 3 are the two ends of the link. Add black
        # markers there, skipping points at distance 0 (presumably the leaves).
        dist_axis = "x" if orientation == "left" else "y"
        for index in [0, 3]:
            # Recomputed each pass: fig.data grows by the marker trace added below, which
            # has no `text` and is therefore filtered out.
            links = [data for data in fig.data
                     if (data["text"] and data[dist_axis][index] > 0)]
            fig.add_trace(go.Scatter(x=[data["x"][index] for data in links],
                                     y=[data["y"][index] for data in links],
                                     marker_color='black',
                                     hovertext=[data["text"][index] for data in links],
                                     hoverinfo="text",
                                     mode='markers',
                                     showlegend=False))
    return fig


def _get_annotations(topic_model,
                     hierarchical_topics: pd.DataFrame,
                     embeddings: csr_matrix,
                     linkage_function: Callable[[csr_matrix], np.ndarray],
                     distance_function: Callable[[csr_matrix], csr_matrix],
                     orientation: str,
                     custom_labels: Union[bool, str] = False) -> List[List[str]]:
    """ Get annotations by replicating linkage function calculation in scipy

    Arguments
        topic_model: A fitted BERTopic instance.
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        embeddings: The c-TF-IDF matrix on which to model the hierarchy
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                            You can pass any function that returns either a square matrix of
                            shape (n_samples, n_samples) with zeros on the diagonal and
                            non-negative values or condensed distance matrix of shape
                            (n_samples * (n_samples - 1) / 2,) containing the upper
                            triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from the aspect with that name.
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.

    Returns:
        text_annotations: One entry per dendrogram link, of the form
                          [left_name, "", "", right_name], to be used within
                          Plotly's `ff.create_dendrogram`
    """
    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]

    # Calculate distance
    X = distance_function(embeddings)
    X = validate_distance_matrix(X, embeddings.shape[0])

    # Calculate linkage and generate dendrogram (without plotting) to get its coordinates
    Z = linkage_function(X)
    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)

    # Map each x-position in the dendrogram to the leaves (topic rows) below it. This relies
    # on scipy placing the k-th leaf at 5 + 10 * k in `icoord`.
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    topic_vals = {tick: [leaf] for leaf, tick in zip(P['leaves'], x_ticks)}

    parent_topic = dict(zip(df.Parent_ID, df.Topics))

    # loop through every link (U-shaped trace) in the dendrogram; scipy lists them in
    # merge order, so both children of a link are already registered in `topic_vals`
    text_annotations = []
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        fst_name = _get_annotation_name(topic_model, fst_topic, custom_labels, df, parent_topic)
        scnd_name = _get_annotation_name(topic_model, scnd_topic, custom_labels, df, parent_topic)

        text_annotations.append([fst_name, "", "", scnd_name])

        # The merged cluster sits at the midpoint between its two children
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = fst_topic + scnd_topic

    return text_annotations


def _get_annotation_name(topic_model,
                         cluster: List[int],
                         custom_labels: Union[bool, str],
                         df: pd.DataFrame,
                         parent_topic: dict) -> str:
    """ Return the display name of one side of a dendrogram link.

    Arguments:
        topic_model: A fitted BERTopic instance.
        cluster: The leaves (topic rows) below this side of the link.
        custom_labels: See `_get_annotations`.
        df: The hierarchical topics without the "Top" parent.
        parent_topic: Mapping of `Parent_ID` to the list of topics it contains.

    Returns:
        The topic label for a single leaf, the matching `Parent_Name` for a merged
        cluster, or an empty string when no parent in `df` contains exactly these topics.
    """
    if len(cluster) == 1:
        # NOTE(review): the leaf index is used directly as the topic ID; this assumes all
        # non-outlier topics are plotted and are numbered 0..n-1.
        topic = cluster[0]
        if isinstance(custom_labels, str):
            words = list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3]
            return f"{topic}_" + "_".join(words)
        if topic_model.custom_labels_ is not None and custom_labels:
            return topic_model.custom_labels_[topic + topic_model._outliers]
        return "_".join([word for word, _ in topic_model.get_topic(topic)][:5])

    # Merged cluster: find the parent whose topic set matches exactly (last match wins).
    cluster_set = set(cluster)
    name = ""
    for key, value in parent_topic.items():
        if set(value) == cluster_set:
            name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
    return name