File size: 5,145 Bytes
19b102a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import numpy as np
from typing import List, Union
from scipy.cluster.hierarchy import fcluster, linkage
from sklearn.metrics.pairwise import cosine_similarity
import plotly.express as px
import plotly.graph_objects as go
def visualize_heatmap(topic_model,
                      topics: List[int] = None,
                      top_n_topics: int = None,
                      n_clusters: int = None,
                      custom_labels: Union[bool, str] = False,
                      title: str = "<b>Similarity Matrix</b>",
                      width: int = 800,
                      height: int = 800) -> go.Figure:
    """ Visualize a heatmap of the topic's similarity matrix

    Based on the cosine similarity matrix between topic embeddings,
    a heatmap is created showing the similarity between topics.

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics (topic ids) to visualize.
        top_n_topics: Only select the top n most frequent topics.
                      Ignored when `topics` is given.
        n_clusters: Create n clusters and order the similarity
                    matrix by those clusters.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        fig: A plotly figure

    Raises:
        ValueError: If `n_clusters` is not lower than the number of
                    unique selected topics.

    Examples:

    To visualize the similarity matrix of
    topics simply run:

    ```python
    topic_model.visualize_heatmap()
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_heatmap()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/heatmap.html"
    style="width:1000px; height: 720px; border: 0px;"></iframe>
    """
    # Select topic embeddings; the leading outlier row (if any) is dropped so
    # that row `i` of `embeddings` corresponds to topic id `i`.
    if topic_model.topic_embeddings_ is not None:
        embeddings = np.array(topic_model.topic_embeddings_)[topic_model._outliers:]
    else:
        embeddings = topic_model.c_tf_idf_[topic_model._outliers:]

    # Select topics based on top_n and topics args
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Order heatmap by similar clusters of topics
    sorted_topics = topics
    if n_clusters:
        if n_clusters >= len(set(topics)):
            raise ValueError("Make sure to set `n_clusters` lower than "
                             "the total number of unique topics.")

        # NOTE: `linkage` receives the square similarity matrix as a 2-D
        # array, i.e. each topic's row of similarities acts as its
        # observation vector for the ward clustering.
        similarity = cosine_similarity(embeddings[topics])
        Z = linkage(similarity, 'ward')
        clusters = fcluster(Z, t=n_clusters, criterion='maxclust')

        # Group topics per cluster, keeping clusters in order of first appearance
        members_by_cluster = {}
        for topic, cluster in zip(topics, clusters):
            members_by_cluster.setdefault(cluster, []).append(topic)
        sorted_topics = [topic
                         for members in members_by_cluster.values()
                         for topic in members]

    # Select embeddings by topic id. Indexing by the position of a topic
    # inside `topics` would pick the rows of topics 0..k-1 whenever a custom
    # subset such as [5, 7, 9] is requested, mislabeling the heatmap.
    embeddings = embeddings[np.array(sorted_topics, dtype=int)]
    distance_matrix = cosine_similarity(embeddings)

    # Create labels
    if isinstance(custom_labels, str):
        aspect = topic_model.topic_aspects_[custom_labels]
        new_labels = [_keyword_label(topic, aspect[topic]) for topic in sorted_topics]
    elif topic_model.custom_labels_ is not None and custom_labels:
        # `custom_labels_` still contains the outlier entry, hence the offset
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers]
                      for topic in sorted_topics]
    else:
        new_labels = [_keyword_label(topic, topic_model.get_topic(topic))
                      for topic in sorted_topics]

    fig = px.imshow(distance_matrix,
                    labels=dict(color="Similarity Score"),
                    x=new_labels,
                    y=new_labels,
                    color_continuous_scale='GnBu'
                    )

    fig.update_layout(
        title={
            'text': f"{title}",
            'y': .95,
            'x': 0.55,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
        showlegend=True,
        legend_title_text='Trend',
    )
    return fig


def _keyword_label(topic, representation):
    """ Build a short axis label "<topic>_<word1>_<word2>_<word3>".

    Arguments:
        topic: The topic id, placed first in the label.
        representation: Sequence of entries whose first element is the
                        keyword (e.g. `(word, score)` pairs); only the
                        first three entries are used.

    Returns:
        label: The joined label, cut to 27 characters plus "..." when it
               would be 30 characters or longer.
    """
    parts = [str(topic)] + [entry[0] for entry in representation[:3]]
    label = "_".join(parts)
    return label if len(label) < 30 else label[:27] + "..."
|