|
|
import copy |
|
|
import pytest |
|
|
from bertopic import BERTopic |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model",
    [
        "base_topic_model",
        "kmeans_pca_topic_model",
        "custom_topic_model",
        "merged_topic_model",
        "reduced_topic_model",
        "online_topic_model",
        "supervised_topic_model",
        "representation_topic_model",
        "zeroshot_topic_model",
    ],
)
def test_full_model(model, documents, request):
    """Test the entire pipeline in one go. This serves as a sanity check to see if the default
    settings result in a good separation of topics.

    NOTE: This does not cover all cases but merely combines it all together.

    Parameters
    ----------
    model : str
        Name of the fitted topic-model fixture to test (resolved via ``request``).
    documents : list of str
        Corpus fixture the models were fitted on.
    request : pytest.FixtureRequest
        Used to look up the model fixture by name.
    """
    # Deep-copy so mutating operations below (reduce/merge/update) do not
    # leak into the shared, session-scoped fixture.
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    if model == "base_topic_model":
        # Additionally exercise pytorch (de)serialization for the base model.
        topic_model.save(
            "model_dir",
            serialization="pytorch",
            save_ctfidf=True,
            save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        topic_model = BERTopic.load("model_dir")
    topics = topic_model.topics_

    # Every assigned topic exposes exactly 10 top words.
    for topic in set(topics):
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    # Same check, but driven by the frequency table's topic IDs.
    for topic in topic_model.get_topic_freq().Topic:
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    # A sensible separation yields more than two topics, and the frequency
    # table and topic dictionary must agree on how many there are.
    assert len(topic_model.get_topic_freq()) > 2
    assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())

    # Per-document info covers the full corpus.
    document_info = topic_model.get_document_info(documents)
    assert len(document_info) == len(documents)

    # Inference on unseen documents.
    doc = "This is a new document to predict."
    topics_test, _probs = topic_model.transform([doc, doc])
    assert len(topics_test) == 2

    # Dynamic topic modeling: frequencies over synthetic timestamps must
    # account for every document and every topic.
    timestamps = [i % 10 for i in range(len(documents))]
    topics_over_time = topic_model.topics_over_time(documents, timestamps)
    assert topics_over_time.Frequency.sum() == len(documents)
    assert len(topics_over_time.Topic.unique()) == len(set(topics))

    # Hierarchical topics: parent IDs are allocated above the flat topic IDs.
    hier_topics = topic_model.hierarchical_topics(documents)
    assert len(hier_topics) > 0
    assert hier_topics.Parent_ID.astype(int).min() > max(topics)

    # Text rendering of the hierarchy.
    tree = topic_model.get_topic_tree(hier_topics, tight_layout=False)
    assert isinstance(tree, str)
    assert len(tree) > 10

    # Topic search by query string; similarity scores are bounded by 1.
    similar_topics, similarity = topic_model.find_topics("query", top_n=2)
    assert len(similar_topics) == 2
    assert len(similarity) == 2
    assert max(similarity) <= 1

    # Reduce the number of topics by one (floored at 2) and verify the
    # reduction took effect without changing the number of document labels.
    nr_topics = len(set(topics))
    nr_topics = 2 if nr_topics < 2 else nr_topics - 1
    topic_model.reduce_topics(documents, nr_topics=nr_topics)
    assert len(topic_model.get_topic_freq()) == nr_topics
    assert len(topic_model.topics_) == len(topics)

    # Updating the representation with a bigram-only vectorizer must change
    # topic 1's words; restoring the original vectorizer re-derives them.
    topic = topic_model.get_topic(1)[:10]
    vectorizer_model = topic_model.vectorizer_model
    topic_model.update_topics(documents, n_gram_range=(2, 2))
    updated_topic = topic_model.get_topic(1)[:10]
    topic_model.update_topics(documents, vectorizer_model=vectorizer_model)
    original_topic = topic_model.get_topic(1)[:10]

    assert topic != updated_topic
    if topic_model.representation_model is not None:
        # With a representation model, the restored words still differ from
        # the pre-update words (the representation step is re-applied).
        assert topic != original_topic

    # Custom label generation: one label per topic.
    topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ")
    assert len(topic_labels) == len(set(topic_model.topics_))

    topic_model.set_topic_labels(topic_labels)
    assert topic_model.custom_labels_ == topic_labels

    # Merging topic 1 into topic 0 must increase topic 0's frequency.
    freq = topic_model.get_topic_freq(0)
    topics_to_merge = [0, 1]
    topic_model.merge_topics(documents, topics_to_merge)
    assert freq < topic_model.get_topic_freq(0)

    # Outlier reduction (only meaningful when outliers (-1) exist).
    if -1 in topics:
        new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0)
        nr_outliers_topic_model = sum(1 for topic in topic_model.topics_ if topic == -1)
        nr_outliers_new_topics = sum(1 for topic in new_topics if topic == -1)

        if topic_model._outliers == 1:
            assert nr_outliers_topic_model > nr_outliers_new_topics

    # Merge two models; the merged model must carry more topics than either
    # input. NOTE(review): this relies on "model_dir" having been written by
    # the earlier "base_topic_model" parametrization — an inter-test ordering
    # dependency worth confirming/removing.
    topic_model1 = BERTopic.load("model_dir")
    merged_model = BERTopic.merge_models([topic_model, topic_model1])
    assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
    assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
|
|
|