| | import copy |
| | import pytest |
| | from bertopic import BERTopic |
| |
|
| |
|
@pytest.mark.parametrize(
    "model",
    [
        "base_topic_model",
        "kmeans_pca_topic_model",
        "custom_topic_model",
        "merged_topic_model",
        "reduced_topic_model",
        "online_topic_model",
        "supervised_topic_model",
        "representation_topic_model",
        "zeroshot_topic_model",
    ],
)
def test_full_model(model, documents, request):
    """Exercise the full topic-model workflow on one fixture model.

    This is an end-to-end sanity check: it walks a copy of the requested
    fixture model through inspection, prediction, temporal and hierarchical
    analysis, reduction, updating, labelling, merging and outlier reduction,
    asserting basic invariants after each step.

    NOTE: The steps below mutate the model in sequence, so their order
    matters. This is a combined smoke test, not exhaustive coverage.
    """
    # Work on a deep copy so the fixture object itself is not mutated here.
    topic_model = copy.deepcopy(request.getfixturevalue(model))

    # Only the base model is round-tripped through save/load.
    if model == "base_topic_model":
        topic_model.save(
            "model_dir",
            serialization="pytorch",
            save_ctfidf=True,
            save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        topic_model = BERTopic.load("model_dir")
    topics = topic_model.topics_

    # Every assigned topic must expose ten entries in its representation.
    for topic_id in set(topics):
        assert len(topic_model.get_topic(topic_id)[:10]) == 10

    # Same check, driven by the topics listed in the frequency table.
    for topic_id in topic_model.get_topic_freq().Topic:
        assert len(topic_model.get_topic(topic_id)[:10]) == 10

    assert len(topic_model.get_topic_freq()) > 2
    assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())

    # Document info: one row per input document.
    document_info = topic_model.get_document_info(documents)
    assert len(document_info) == len(documents)

    # Transform: two inputs must produce two topic predictions.
    doc = "This is a new document to predict."
    predicted_topics, _ = topic_model.transform([doc, doc])
    assert len(predicted_topics) == 2

    # Topics over time, using ten synthetic timestamp buckets (0..9).
    timestamps = [position % 10 for position in range(len(documents))]
    topics_over_time = topic_model.topics_over_time(documents, timestamps)
    assert topics_over_time.Frequency.sum() == len(documents)
    assert len(topics_over_time.Topic.unique()) == len(set(topics))

    # Hierarchical topics: parent IDs must lie above every existing topic ID.
    hier_topics = topic_model.hierarchical_topics(documents)
    assert len(hier_topics) > 0
    assert hier_topics.Parent_ID.astype(int).min() > max(topics)

    # Topic tree is rendered as a non-trivial string.
    tree = topic_model.get_topic_tree(hier_topics, tight_layout=False)
    assert isinstance(tree, str)
    assert len(tree) > 10

    # Topic search returns top_n results with similarities of at most 1.
    similar_topics, similarity = topic_model.find_topics("query", top_n=2)
    assert len(similar_topics) == 2
    assert len(similarity) == 2
    assert max(similarity) <= 1

    # Topic reduction: drop one topic, or request 2 when fewer than 2 exist.
    distinct_topics = len(set(topics))
    nr_topics = distinct_topics - 1 if distinct_topics >= 2 else 2
    topic_model.reduce_topics(documents, nr_topics=nr_topics)
    assert len(topic_model.get_topic_freq()) == nr_topics
    assert len(topic_model.topics_) == len(topics)

    # Update topics: switch to bigrams, then back to the saved vectorizer.
    baseline_words = topic_model.get_topic(1)[:10]
    saved_vectorizer = topic_model.vectorizer_model
    topic_model.update_topics(documents, n_gram_range=(2, 2))
    bigram_words = topic_model.get_topic(1)[:10]

    topic_model.update_topics(documents, vectorizer_model=saved_vectorizer)
    restored_words = topic_model.get_topic(1)[:10]

    assert baseline_words != bigram_words
    if topic_model.representation_model is not None:
        assert baseline_words != restored_words

    # Generated labels: one per distinct topic currently assigned.
    topic_labels = topic_model.generate_topic_labels(
        nr_words=3, topic_prefix=False, word_length=10, separator=", "
    )
    assert len(topic_labels) == len(set(topic_model.topics_))

    # Custom labels are stored as given.
    topic_model.set_topic_labels(topic_labels)
    assert topic_model.custom_labels_ == topic_labels

    # Merging topics 0 and 1 must increase the frequency reported for topic 0.
    freq_before_merge = topic_model.get_topic_freq(0)
    topic_model.merge_topics(documents, [0, 1])
    assert freq_before_merge < topic_model.get_topic_freq(0)

    # Outlier reduction only applies when the original run produced outliers.
    if -1 in topics:
        new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0)
        outliers_in_model = sum(1 for label in topic_model.topics_ if label == -1)
        outliers_after_reduce = sum(1 for label in new_topics if label == -1)

        if topic_model._outliers == 1:
            assert outliers_in_model > outliers_after_reduce

    # Merging models: the merged model must list more topics than either input.
    # NOTE(review): "model_dir" is only written above in the base_topic_model
    # case; the other cases presumably rely on that case having run first --
    # confirm the intended test ordering.
    topic_model1 = BERTopic.load("model_dir")
    merged_model = BERTopic.merge_models([topic_model, topic_model1])

    assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
    assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
| |
|