| | import numpy as np |
| | import pandas as pd |
| | from typing import Callable, List, Union |
| | from scipy.sparse import csr_matrix |
| | from scipy.cluster import hierarchy as sch |
| | from scipy.spatial.distance import squareform |
| | from sklearn.metrics.pairwise import cosine_similarity |
| |
|
| | import plotly.graph_objects as go |
| | import plotly.figure_factory as ff |
| |
|
| | from bertopic._utils import validate_distance_matrix |
| |
|
def visualize_hierarchy(topic_model,
                        orientation: str = "left",
                        topics: List[int] = None,
                        top_n_topics: int = None,
                        custom_labels: Union[bool, str] = False,
                        title: str = "<b>Hierarchical Clustering</b>",
                        width: int = 1000,
                        height: int = 600,
                        hierarchical_topics: pd.DataFrame = None,
                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
                        color_threshold: int = 1) -> go.Figure:
    """ Visualize a hierarchical structure of the topics

    By default, a ward linkage function is used to perform the
    hierarchical clustering based on the cosine distance matrix
    between the topic representations. The c-TF-IDF matrix is used
    when `topic_model.c_tf_idf_` is available, otherwise the
    topic embeddings are used.

    Arguments:
        topic_model: A fitted BERTopic instance.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        topics: A selection of topics to visualize
        top_n_topics: Only select the top n most frequent topics
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.
        title: Title of the plot.
        width: The width of the figure. Only works if orientation is set to 'left'
        height: The height of the figure. Only works if orientation is set to 'bottom'
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                           You can pass any function that returns either a square matrix of
                           shape (n_samples, n_samples) with zeros on the diagonal and
                           non-negative values or condensed distance matrix of shape
                           (n_samples * (n_samples - 1) / 2,) containing the upper
                           triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        color_threshold: Value at which the separation of clusters will be made which
                         will result in different colors for different clusters.
                         A higher value will typically lead to fewer colored clusters.

    Returns:
        fig: A plotly figure

    Examples:

    To visualize the hierarchical structure of
    topics simply run:

    ```python
    topic_model.visualize_hierarchy()
    ```

    If you also want the labels visualized of hierarchical topics,
    run the following:

    ```python
    # Extract hierarchical topics and their representations
    hierarchical_topics = topic_model.hierarchical_topics(docs)

    # Visualize these representations
    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
    ```

    If you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_hierarchy()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/hierarchy.html"
    style="width:1000px; height: 680px; border: 0px;"></iframe>
    """
    if distance_function is None:
        distance_function = lambda x: 1 - cosine_similarity(x)

    if linkage_function is None:
        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)

    # Select topics based on the `topics` and `top_n_topics` arguments;
    # the outlier topic (-1) is never part of the hierarchy.
    # NOTE(review): `top_n_topics` takes the first rows of `get_topic_freq()`,
    # which presumably is sorted by frequency -- confirm.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Map the selected topic ids to their row position among all sorted topic ids
    all_topics = sorted(list(topic_model.get_topics().keys()))
    indices = np.array([all_topics.index(topic) for topic in topics])

    # Select the topic representations to cluster
    if topic_model.c_tf_idf_ is not None:
        embeddings = topic_model.c_tf_idf_[indices]
    else:
        embeddings = np.array(topic_model.topic_embeddings_)[indices]

    # Hover annotations are only created when no topic selection was applied
    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
        annotations = _get_annotations(topic_model=topic_model,
                                       hierarchical_topics=hierarchical_topics,
                                       embeddings=embeddings,
                                       distance_function=distance_function,
                                       linkage_function=linkage_function,
                                       orientation=orientation,
                                       custom_labels=custom_labels)
    else:
        annotations = None

    # Wrap the distance function so its output is passed through
    # `validate_distance_matrix` before it reaches the linkage function
    distance_function_viz = lambda x: validate_distance_matrix(
        distance_function(x), embeddings.shape[0])

    # Create dendrogram
    fig = ff.create_dendrogram(embeddings,
                               orientation=orientation,
                               distfun=distance_function_viz,
                               linkagefun=linkage_function,
                               hovertext=annotations,
                               color_threshold=color_threshold)

    # Create nicer labels. The tick texts are used as (string) positions into
    # `topics`, so translate them to real topic ids once for every label style.
    label_axis = "yaxis" if orientation == "left" else "xaxis"
    tick_topics = [topics[int(tick)] for tick in fig.layout[label_axis]["ticktext"]]
    if isinstance(custom_labels, str):
        aspect = topic_model.topic_aspects_[custom_labels]
        new_labels = [_shorten_hierarchy_label(topic, aspect[topic]) for topic in tick_topics]
    elif topic_model.custom_labels_ is not None and custom_labels:
        # `custom_labels_` is indexed by topic id shifted by `_outliers`
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in tick_topics]
    else:
        new_labels = [_shorten_hierarchy_label(topic, topic_model.get_topic(topic))
                      for topic in tick_topics]

    # Stylize layout
    fig.update_layout(
        plot_bgcolor='#ECEFF1',
        template="plotly_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    # Stylize orientation: the axis holding the topics grows with their number
    if orientation == "left":
        fig.update_layout(height=200 + (15 * len(topics)),
                          width=width,
                          yaxis=dict(tickmode="array",
                                     ticktext=new_labels))

        # Fix empty space on the bottom of the graph
        y_max = max(trace['y'].max() + 5 for trace in fig['data'])
        y_min = min(trace['y'].min() - 5 for trace in fig['data'])
        fig.update_layout(yaxis=dict(range=[y_min, y_max]))

    else:
        fig.update_layout(width=200 + (15 * len(topics)),
                          height=height,
                          xaxis=dict(tickmode="array",
                                     ticktext=new_labels))

    if hierarchical_topics is not None:
        # Add a hoverable black marker on the first and last point of every
        # annotated trace, skipping points whose depth coordinate is not positive.
        depth_axis = "x" if orientation == "left" else "y"
        for index in [0, 3]:
            annotated = [data for data in fig.data
                         if (data["text"] and data[depth_axis][index] > 0)]
            xs = [data["x"][index] for data in annotated]
            ys = [data["y"][index] for data in annotated]
            hovertext = [data["text"][index] for data in annotated]

            fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
                                     hovertext=hovertext, hoverinfo="text",
                                     mode='markers', showlegend=False))
    return fig


def _shorten_hierarchy_label(topic_id, words, n_words: int = 3, max_length: int = 30) -> str:
    """ Build a `<topic id>_<word>_<word>_...` axis label, truncated with "..."

    Arguments:
        topic_id: The topic identifier used as the label prefix.
        words: Sequence of entries whose first element is the word,
               e.g. `(word, score)` tuples.
        n_words: Number of leading words appended to the topic id.
        max_length: Labels of at least this length are cut to
                    `max_length - 3` characters followed by "...".

    Returns:
        label: The formatted label
    """
    parts = [str(topic_id)] + [entry[0] for entry in words[:n_words]]
    label = "_".join(parts)
    return label if len(label) < max_length else label[:max_length - 3] + "..."
| |
|
| |
|
def _get_annotations(topic_model,
                     hierarchical_topics: pd.DataFrame,
                     embeddings: csr_matrix,
                     linkage_function: Callable[[csr_matrix], np.ndarray],
                     distance_function: Callable[[csr_matrix], csr_matrix],
                     orientation: str,
                     custom_labels: Union[bool, str] = False) -> List[List[str]]:
    """ Get annotations by replicating linkage function calculation in scipy

    Arguments:
        topic_model: A fitted BERTopic instance.
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        embeddings: The c-TF-IDF matrix on which to model the hierarchy
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                           You can pass any function that returns either a square matrix of
                           shape (n_samples, n_samples) with zeros on the diagonal and
                           non-negative values or condensed distance matrix of shape
                           (n_samples * (n_samples - 1) / 2,) containing the upper
                           triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, the name of the aspect in `topic_model.topic_aspects_`
                       to take the labels from.
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.

    Returns:
        text_annotations: Annotations to be used within Plotly's `ff.create_dendrogram`.
                          One `[first_child_name, "", "", second_child_name]` entry per
                          dendrogram link. A merged cluster that has no matching row
                          in `hierarchical_topics` gets an empty name.
    """
    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]

    # Calculate distance
    X = distance_function(embeddings)
    X = validate_distance_matrix(X, embeddings.shape[0])

    # Calculate linkage and generate dendrogram
    Z = linkage_function(X)
    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)

    # Leaves are assumed to sit at 5, 15, 25, ... on the leaf axis (the spacing
    # scipy's dendrogram is expected to use for `icoord`); map every such
    # coordinate to the single topic found at that leaf.
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    topic_vals = {tick: [leaf] for leaf, tick in zip(P['leaves'], x_ticks)}

    # Look-up from the set of topics inside a parent to that parent's name.
    # For a repeated Parent_ID the last `Topics` value and the first
    # `Parent_Name` value are used; a later parent with the same topic set wins.
    first_parent_name = {}
    for parent_id, parent_name in zip(df.Parent_ID, df.Parent_Name):
        first_parent_name.setdefault(parent_id, parent_name)
    parent_topic = dict(zip(df.Parent_ID, df.Topics))
    name_by_members = {frozenset(members): first_parent_name[parent_id]
                       for parent_id, members in parent_topic.items()}

    def _leaf_name(topic):
        """ Name of a single, un-merged topic according to `custom_labels` """
        if isinstance(custom_labels, str):
            words = [entry[0] for entry in topic_model.topic_aspects_[custom_labels][topic]]
            return f"{topic}_" + "_".join(words[:3])
        if topic_model.custom_labels_ is not None and custom_labels:
            # `custom_labels_` is indexed by topic id shifted by `_outliers`
            return topic_model.custom_labels_[topic + topic_model._outliers]
        return "_".join([word for word, _ in topic_model.get_topic(topic)][:5])

    def _cluster_name(cluster_topics):
        """ Name of a dendrogram node holding one or more topics """
        if len(cluster_topics) == 1:
            return _leaf_name(cluster_topics[0])
        # An unmatched cluster gets an empty label instead of silently reusing
        # the name of a previously processed node (or failing on an unbound name).
        return name_by_members.get(frozenset(cluster_topics), "")

    # For each link, name both children and register the merged node under the
    # coordinate of its centre so that later links can refer to it.
    text_annotations = []
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        text_annotations.append([_cluster_name(fst_topic), "", "", _cluster_name(scnd_topic)])

        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = fst_topic + scnd_topic

    return text_annotations
| |
|