TopicModelingRepo / BERTopic /tests /test_bertopic.py
kisejin's picture
Upload 261 files
19b102a verified
import copy
import pytest
from bertopic import BERTopic
@pytest.mark.parametrize(
    "model",
    [
        "base_topic_model",
        "kmeans_pca_topic_model",
        "custom_topic_model",
        "merged_topic_model",
        "reduced_topic_model",
        "online_topic_model",
        "supervised_topic_model",
        "representation_topic_model",
        "zeroshot_topic_model",
    ],
)
def test_full_model(model, documents, request):
    """Exercise the main BERTopic API surface on one fitted model fixture.

    The fixture named by ``model`` is deep-copied so that the mutating steps
    below (topic reduction, representation updates, merging) do not leak into
    other tests. For ``base_topic_model`` the copy is additionally saved to
    ``model_dir`` and reloaded, so that the remaining checks run against the
    deserialized model.

    NOTE: This is a sanity check that chains the features together; it does
    not cover every case of each individual feature.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    if model == "base_topic_model":
        # Round-trip through disk; the rest of the test uses the loaded model.
        topic_model.save(
            "model_dir",
            serialization="pytorch",
            save_ctfidf=True,
            save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        topic_model = BERTopic.load("model_dir")
    topics = topic_model.topics_

    # Every assigned topic exposes (at least) ten representation words.
    for topic_id in set(topics):
        top_words = topic_model.get_topic(topic_id)[:10]
        assert len(top_words) == 10

    # The same holds for every topic listed in the frequency table.
    for topic_id in topic_model.get_topic_freq().Topic:
        top_words = topic_model.get_topic(topic_id)[:10]
        assert len(top_words) == 10

    assert len(topic_model.get_topic_freq()) > 2
    assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())

    # Document info: one row per input document.
    doc_info = topic_model.get_document_info(documents)
    assert len(doc_info) == len(documents)

    # Transform: one prediction per unseen document.
    unseen_doc = "This is a new document to predict."
    predicted_topics, _predicted_probs = topic_model.transform([unseen_doc, unseen_doc])
    assert len(predicted_topics) == 2

    # Topics over time, using ten cyclic timestamp buckets.
    timestamps = [position % 10 for position, _ in enumerate(documents)]
    over_time = topic_model.topics_over_time(documents, timestamps)
    assert over_time.Frequency.sum() == len(documents)
    assert len(over_time.Topic.unique()) == len(set(topics))

    # Hierarchical topics: parent ids must lie above every existing topic id.
    hierarchy = topic_model.hierarchical_topics(documents)
    assert len(hierarchy) > 0
    assert hierarchy.Parent_ID.astype(int).min() > max(topics)

    # Topic tree is rendered as a non-trivial string.
    topic_tree = topic_model.get_topic_tree(hierarchy, tight_layout=False)
    assert isinstance(topic_tree, str)
    assert len(topic_tree) > 10

    # Find topics: top_n results with similarities bounded by 1.
    found_topics, found_scores = topic_model.find_topics("query", top_n=2)
    assert len(found_topics) == 2
    assert len(found_scores) == 2
    assert max(found_scores) <= 1

    # Topic reduction: drop one topic, but never request fewer than two.
    nr_topics = len(set(topics))
    if nr_topics < 2:
        nr_topics = 2
    else:
        nr_topics = nr_topics - 1
    topic_model.reduce_topics(documents, nr_topics=nr_topics)
    assert len(topic_model.get_topic_freq()) == nr_topics
    assert len(topic_model.topics_) == len(topics)

    # Update topics: switch to bigrams, then restore the previous vectorizer.
    words_before = topic_model.get_topic(1)[:10]
    previous_vectorizer = topic_model.vectorizer_model
    topic_model.update_topics(documents, n_gram_range=(2, 2))
    words_bigram = topic_model.get_topic(1)[:10]
    topic_model.update_topics(documents, vectorizer_model=previous_vectorizer)
    words_restored = topic_model.get_topic(1)[:10]
    assert words_before != words_bigram
    if topic_model.representation_model is not None:
        assert words_before != words_restored

    # Generated labels: one per distinct topic currently assigned.
    topic_labels = topic_model.generate_topic_labels(
        nr_words=3, topic_prefix=False, word_length=10, separator=", "
    )
    assert len(topic_labels) == len(set(topic_model.topics_))

    # Custom labels are stored exactly as supplied.
    topic_model.set_topic_labels(topic_labels)
    assert topic_model.custom_labels_ == topic_labels

    # Merging topics 0 and 1 must grow topic 0.
    size_before_merge = topic_model.get_topic_freq(0)
    topics_to_merge = [0, 1]
    topic_model.merge_topics(documents, topics_to_merge)
    assert size_before_merge < topic_model.get_topic_freq(0)

    # Outlier reduction only applies when the original assignment has outliers.
    if -1 in topics:
        new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0)
        outliers_in_model = sum(1 for assigned in topic_model.topics_ if assigned == -1)
        outliers_after = sum(1 for assigned in new_topics if assigned == -1)
        if topic_model._outliers == 1:
            assert outliers_in_model > outliers_after

    # Combine models.
    # NOTE(review): "model_dir" is only written above for the
    # "base_topic_model" parametrization, so the other parametrizations
    # presumably rely on that case having run first -- confirm.
    reloaded_model = BERTopic.load("model_dir")
    merged_model = BERTopic.merge_models([topic_model, reloaded_model])
    assert len(merged_model.get_topic_info()) > len(reloaded_model.get_topic_info())
    assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())