| import copy |
| import pytest |
|
|
|
|
@pytest.mark.parametrize(
    "model",
    [
        "kmeans_pca_topic_model",
        "base_topic_model",
        "custom_topic_model",
        "merged_topic_model",
        "reduced_topic_model",
        "online_topic_model",
    ],
)
def test_merge(model, documents, request):
    """Merge topics ``[1, 2]`` twice and check the model stays consistent.

    After each merge round the test asserts that:

    * the number of distinct topics in ``topics_`` dropped by one for every
      merge performed so far,
    * the counts reported by ``get_topic_info()`` still sum to
      ``len(documents)``, and
    * mapping the clustering model's raw ``labels_`` through
      ``topic_mapper_.get_mappings`` reproduces ``topics_`` (only the trailing
      slice ``topics_[950:]`` for the online model).

    Args:
        model: Name of the topic-model fixture to test.
        documents: Fixture providing the documents passed to ``merge_topics``.
        request: pytest request object, used to resolve ``model`` by name.
    """
    # Deep copy so merging does not mutate the fixture instance, which other
    # tests may share.
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    nr_topics = len(set(topic_model.topics_))

    # The same ids are merged in both rounds; the second assertion below
    # expects the topic count to shrink by one more, i.e. ids 1 and 2 still
    # name two distinct topics after the first merge.
    topics_to_merge = [1, 2]

    for nr_merges in (1, 2):
        topic_model.merge_topics(documents, topics_to_merge)

        raw_labels = list(topic_model.hdbscan_model.labels_)
        mappings = topic_model.topic_mapper_.get_mappings(raw_labels)
        mapped_labels = [mappings[label] for label in raw_labels]

        assert nr_topics == len(set(topic_model.topics_)) + nr_merges
        assert topic_model.get_topic_info().Count.sum() == len(documents)
        if model == "online_topic_model":
            # NOTE(review): presumably the online fixture's clustering labels_
            # only cover documents from index 950 onward (e.g. the last
            # partial_fit batch) — confirm against the fixture definition.
            assert mapped_labels == topic_model.topics_[950:]
        else:
            assert mapped_labels == topic_model.topics_
|
|