| | import copy |
| | import pytest |
| |
|
| |
|
@pytest.mark.parametrize('model', ['kmeans_pca_topic_model',
                                   'base_topic_model',
                                   'custom_topic_model',
                                   'merged_topic_model',
                                   'reduced_topic_model',
                                   'online_topic_model'])
def test_merge(model, documents, request):
    """Check ``merge_topics`` on each topic-model fixture, applied twice in a row.

    After every ``merge_topics(documents, [1, 2])`` call the test expects:

    * the number of distinct values in ``topics_`` to have dropped by one per
      merge performed so far;
    * ``get_topic_info().Count`` to still sum to ``len(documents)``;
    * the clustering model's ``labels_``, mapped through ``topic_mapper_``,
      to equal ``topics_`` (only the slice from index 950 onward for the
      online model).

    The fixture object is deep-copied so the merges do not mutate it.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    nr_topics = len(set(topic_model.topics_))

    # The same topic ids are merged on both passes; each pass is expected to
    # remove exactly one more distinct topic than the previous one.
    for nr_merges in (1, 2):
        topics_to_merge = [1, 2]
        topic_model.merge_topics(documents, topics_to_merge)

        labels = topic_model.hdbscan_model.labels_
        # NOTE(review): in upstream BERTopic, `get_mappings` takes a boolean
        # `original_topics` flag; a non-empty list is truthy and so behaves
        # like `original_topics=True`. Confirm and consider passing the flag.
        mappings = topic_model.topic_mapper_.get_mappings(list(labels))
        mapped_labels = [mappings[label] for label in labels]

        assert nr_topics == len(set(topic_model.topics_)) + nr_merges
        assert topic_model.get_topic_info().Count.sum() == len(documents)

        expected_labels = topic_model.topics_
        if model == "online_topic_model":
            # Presumably the online fixture's clustering model only keeps the
            # labels of its last fitted batch, which starts at document 950
            # (TODO: confirm against the fixture definition).
            expected_labels = expected_labels[950:]
        assert mapped_labels == expected_labels
| |
|