|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from typing import Callable, List, Union |
|
|
from scipy.sparse import csr_matrix |
|
|
from scipy.cluster import hierarchy as sch |
|
|
from scipy.spatial.distance import squareform |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
import plotly.graph_objects as go |
|
|
import plotly.figure_factory as ff |
|
|
|
|
|
from bertopic._utils import validate_distance_matrix |
|
|
|
|
|
def _format_tick_label(topic, representation) -> str:
    """ Build a short axis label: the topic id followed by its first three words.

    Arguments:
        topic: The topic id shown as the label prefix.
        representation: Sequence of entries whose first element is a word
                        (e.g. `(word, score)` tuples).

    Returns:
        label: "<topic>_<w1>_<w2>_<w3>", cut to 27 characters plus "..."
               when it would be 30 characters or longer.
    """
    label = "_".join([str(topic)] + [entry[0] for entry in representation[:3]])
    return label if len(label) < 30 else label[:27] + "..."


def visualize_hierarchy(topic_model,
                        orientation: str = "left",
                        topics: List[int] = None,
                        top_n_topics: int = None,
                        custom_labels: Union[bool, str] = False,
                        title: str = "<b>Hierarchical Clustering</b>",
                        width: int = 1000,
                        height: int = 600,
                        hierarchical_topics: pd.DataFrame = None,
                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
                        color_threshold: int = 1) -> go.Figure:
    """ Visualize a hierarchical structure of the topics

    By default, a ward linkage function is used to perform the
    hierarchical clustering based on the cosine distance
    matrix between the c-TF-IDF representations of the topics
    (or the topic embeddings when no c-TF-IDF matrix is available).

    Arguments:
        topic_model: A fitted BERTopic instance.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        topics: A selection of topics to visualize
        top_n_topics: Only select the top n most frequent topics
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.
        title: Title of the plot.
        width: The width of the figure. Only works if orientation is set to 'left'
        height: The height of the figure. Only works if orientation is set to 'bottom'
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                           You can pass any function that returns either a square matrix of
                           shape (n_samples, n_samples) with zeros on the diagonal and
                           non-negative values or condensed distance matrix of shape
                           (n_samples * (n_samples - 1) / 2,) containing the upper
                           triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        color_threshold: Value at which the separation of clusters will be made which
                         will result in different colors for different clusters.
                         A higher value will typically lead to fewer colored clusters.

    Returns:
        fig: A plotly figure

    Raises:
        ValueError: If one of the selected topics does not exist in the model.

    Examples:

    To visualize the hierarchical structure of
    topics simply run:

    ```python
    topic_model.visualize_hierarchy()
    ```

    If you also want the labels visualized of hierarchical topics,
    run the following:

    ```python
    # Extract hierarchical topics and their representations
    hierarchical_topics = topic_model.hierarchical_topics(docs)

    # Visualize these representations
    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
    ```

    If you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_hierarchy()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/hierarchy.html"
    style="width:1000px; height: 680px; border: 0px;"></iframe>
    """
    if distance_function is None:
        distance_function = lambda x: 1 - cosine_similarity(x)

    if linkage_function is None:
        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)

    # Select topics based on the `topics` and `top_n_topics` arguments;
    # the outlier topic (-1) is never part of the default selection
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Map every selected topic to its row position among all sorted topics
    all_topics = sorted(topic_model.get_topics().keys())
    row_of_topic = {topic: row for row, topic in enumerate(all_topics)}
    unknown_topics = [topic for topic in topics if topic not in row_of_topic]
    if unknown_topics:
        raise ValueError(f"The following topics do not exist in the model: {unknown_topics}")
    indices = np.array([row_of_topic[topic] for topic in topics])

    # Select the topic representations to cluster
    if topic_model.c_tf_idf_ is not None:
        embeddings = topic_model.c_tf_idf_[indices]
    else:
        embeddings = np.array(topic_model.topic_embeddings_)[indices]

    # Hover annotations are only generated when all (non-outlier) topics are shown
    if hierarchical_topics is not None and len(topics) == len(freq_df):
        annotations = _get_annotations(topic_model=topic_model,
                                       hierarchical_topics=hierarchical_topics,
                                       embeddings=embeddings,
                                       distance_function=distance_function,
                                       linkage_function=linkage_function,
                                       orientation=orientation,
                                       custom_labels=custom_labels)
    else:
        annotations = None

    # Wrap the distance function so its output is validated before it is
    # handed to the linkage function
    def distance_function_viz(x):
        return validate_distance_matrix(distance_function(x), embeddings.shape[0])

    # Create dendrogram
    fig = ff.create_dendrogram(embeddings,
                               orientation=orientation,
                               distfun=distance_function_viz,
                               linkagefun=linkage_function,
                               hovertext=annotations,
                               color_threshold=color_threshold)

    # Create nicer labels. The tick texts are the dendrogram's leaf ids, i.e.
    # row positions in `embeddings` given as strings, so they are translated
    # to the actual topic ids before any label lookup.
    axis = "yaxis" if orientation == "left" else "xaxis"
    leaf_topics = [topics[int(leaf)] for leaf in fig.layout[axis]["ticktext"]]
    if isinstance(custom_labels, str):
        aspect = topic_model.topic_aspects_[custom_labels]
        new_labels = [_format_tick_label(topic, aspect[topic]) for topic in leaf_topics]
    elif topic_model.custom_labels_ is not None and custom_labels:
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in leaf_topics]
    else:
        new_labels = [_format_tick_label(topic, topic_model.get_topic(topic)) for topic in leaf_topics]

    # Stylize layout
    fig.update_layout(
        plot_bgcolor='#ECEFF1',
        template="plotly_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    # Stylize orientation: the axis holding the leaves grows with the number of topics
    if orientation == "left":
        fig.update_layout(height=200 + (15 * len(topics)),
                          width=width,
                          yaxis=dict(tickmode="array",
                                     ticktext=new_labels))

        # Fix empty space on the bottom of the graph
        y_max = max(trace['y'].max() + 5 for trace in fig['data'])
        y_min = min(trace['y'].min() - 5 for trace in fig['data'])
        fig.update_layout(yaxis=dict(range=[y_min, y_max]))

    else:
        fig.update_layout(width=200 + (15 * len(topics)),
                          height=height,
                          xaxis=dict(tickmode="array",
                                     ticktext=new_labels))

    if hierarchical_topics is not None:
        # Add a black, hoverable marker at both lower ends (points 0 and 3) of
        # every annotated link that does not end in a leaf (coordinate > 0)
        depth_axis = "x" if orientation == "left" else "y"
        for index in [0, 3]:
            marked = [trace for trace in fig.data
                      if (trace["text"] and trace[depth_axis][index] > 0)]
            xs = [trace["x"][index] for trace in marked]
            ys = [trace["y"][index] for trace in marked]
            hovertext = [trace["text"][index] for trace in marked]

            fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
                                     hovertext=hovertext, hoverinfo="text",
                                     mode='markers', showlegend=False))
    return fig
|
|
|
|
|
|
|
|
def _leaf_annotation(topic_model, topic: int, custom_labels: Union[bool, str]) -> str:
    """ Return the hover name of a single, un-merged topic. """
    if isinstance(custom_labels, str):
        # First three words of the requested aspect, prefixed with the topic id
        aspect_words = [entry[0] for entry in topic_model.topic_aspects_[custom_labels][topic][:3]]
        return f"{topic}_" + "_".join(aspect_words)
    if topic_model.custom_labels_ is not None and custom_labels:
        return topic_model.custom_labels_[topic + topic_model._outliers]
    return "_".join([word for word, _ in topic_model.get_topic(topic)][:5])


def _get_annotations(topic_model,
                     hierarchical_topics: pd.DataFrame,
                     embeddings: csr_matrix,
                     linkage_function: Callable[[csr_matrix], np.ndarray],
                     distance_function: Callable[[csr_matrix], csr_matrix],
                     orientation: str,
                     custom_labels: Union[bool, str] = False) -> List[List[str]]:

    """ Get annotations by replicating linkage function calculation in scipy

    Arguments:
        topic_model: A fitted BERTopic instance.
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             if both `topics` and `top_n_topics` are not set.
        embeddings: The c-TF-IDF matrix on which to model the hierarchy
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                           You can pass any function that returns either a square matrix of
                           shape (n_samples, n_samples) with zeros on the diagonal and
                           non-negative values or condensed distance matrix of shape
                           (n_samples * (n_samples - 1) / 2,) containing the upper
                           triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, the name of the aspect in `topic_model.topic_aspects_`
                       to take the labels from.
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.

    Returns:
        text_annotations: Annotations to be used within Plotly's `ff.create_dendrogram`.
                          One `[first_name, "", "", second_name]` list per link; a merged
                          cluster without a matching row in `hierarchical_topics` is
                          annotated with an empty string.
    """
    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]

    # Calculate distance
    X = distance_function(embeddings)
    X = validate_distance_matrix(X, embeddings.shape[0])

    # Calculate linkage and generate dendrogram
    Z = linkage_function(X)
    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)

    # Leaves sit at positions 5, 15, 25, ...; remember which topics live at each position
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    topic_vals = {position: [leaf] for leaf, position in zip(P['leaves'], x_ticks)}

    # Name of every parent: first row per Parent_ID supplies the name, the last
    # row per Parent_ID supplies the member topics
    parent_names = {}
    for parent_id, parent_name in zip(df.Parent_ID, df.Parent_Name):
        parent_names.setdefault(parent_id, parent_name)
    parent_topic = dict(zip(df.Parent_ID, df.Topics))

    # Look-up from a cluster's member topics to its name; when several parents
    # share the same members, the last one wins
    cluster_names = {}
    for parent_id, members in parent_topic.items():
        cluster_names[frozenset(members)] = parent_names[parent_id]

    def cluster_annotation(members):
        if len(members) == 1:
            return _leaf_annotation(topic_model, members[0], custom_labels)
        # No matching parent (e.g. a different linkage than the one used for
        # `hierarchical_topics`): leave the annotation empty instead of failing
        # or re-using the name of a previous cluster
        return cluster_names.get(frozenset(members), "")

    text_annotations = []
    for trace in P['icoord']:
        # A link connects the clusters located at its first and third x-coordinate
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        text_annotations.append([cluster_annotation(fst_topic), "", "", cluster_annotation(scnd_topic)])

        # The merged cluster is positioned in the middle of its two children
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = fst_topic + scnd_topic

    return text_annotations
|
|
|