File size: 4,607 Bytes
19b102a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import copy
import pytest
from bertopic import BERTopic


@pytest.mark.parametrize(
    'model',
    [
        ("base_topic_model"),
        ('kmeans_pca_topic_model'),
        ('custom_topic_model'),
        ('merged_topic_model'),
        ('reduced_topic_model'),
        ('online_topic_model'),
        ('supervised_topic_model'),
        ('representation_topic_model'),
        ('zeroshot_topic_model')
    ])
def test_full_model(model, documents, request, tmp_path):
    """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default
    settings result in a good separation of topics.

    The base model is saved to this test's own temporary directory (``tmp_path``) and loaded
    back from there. That round trip exercises (de)serialization for ``base_topic_model`` and
    provides the second model for the ``merge_models`` check at the end. Because the save
    happens inside every parametrization, no case depends on another having run first, and
    nothing is written to the working directory.

    NOTE: This does not cover all cases but merely combines it all together

    Arguments:
        model: Name of the topic-model fixture under test.
        documents: Fixture with the documents; presumably the ones the fixture models were
                   fitted on, since per-document results are compared against its length.
        request: Pytest request object, used to resolve ``model`` to its fixture value.
        tmp_path: Pytest-provided temporary directory holding the saved base model.
    """
    # Persist the base model per test so every parametrization can load it,
    # regardless of test ordering or selection (-k, xdist, --lf, ...).
    model_dir = str(tmp_path / "model_dir")
    request.getfixturevalue("base_topic_model").save(
        model_dir,
        serialization="pytorch",
        save_ctfidf=True,
        save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    )

    if model == "base_topic_model":
        # Run the remainder of the test on a deserialized model
        topic_model = BERTopic.load(model_dir)
    else:
        # Deep copy so the mutating calls below do not affect the shared fixture
        topic_model = copy.deepcopy(request.getfixturevalue(model))
    topics = topic_model.topics_

    for topic in set(topics):
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    for topic in topic_model.get_topic_freq().Topic:
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    assert len(topic_model.get_topic_freq()) > 2
    assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())

    # Test extraction of document info
    document_info = topic_model.get_document_info(documents)
    assert len(document_info) == len(documents)

    # Test transform
    doc = "This is a new document to predict."
    topics_test, probs_test = topic_model.transform([doc, doc])

    assert len(topics_test) == 2

    # Test topics over time
    timestamps = [i % 10 for i in range(len(documents))]
    topics_over_time = topic_model.topics_over_time(documents, timestamps)

    assert topics_over_time.Frequency.sum() == len(documents)
    assert len(topics_over_time.Topic.unique()) == len(set(topics))

    # Test hierarchical topics
    hier_topics = topic_model.hierarchical_topics(documents)

    assert len(hier_topics) > 0
    # Parent IDs must not collide with any leaf topic ID
    assert hier_topics.Parent_ID.astype(int).min() > max(topics)

    # Test creation of topic tree
    tree = topic_model.get_topic_tree(hier_topics, tight_layout=False)
    assert isinstance(tree, str)
    assert len(tree) > 10

    # Test find topic
    similar_topics, similarity = topic_model.find_topics("query", top_n=2)
    assert len(similar_topics) == 2
    assert len(similarity) == 2
    assert max(similarity) <= 1

    # Test topic reduction: ask for one topic fewer than we have (but at least 2)
    nr_topics = len(set(topics))
    nr_topics = 2 if nr_topics < 2 else nr_topics - 1
    topic_model.reduce_topics(documents, nr_topics=nr_topics)

    assert len(topic_model.get_topic_freq()) == nr_topics
    assert len(topic_model.topics_) == len(topics)

    # Test update topics: switch to bigrams, then restore the original vectorizer
    topic = topic_model.get_topic(1)[:10]
    vectorizer_model = topic_model.vectorizer_model
    topic_model.update_topics(documents, n_gram_range=(2, 2))

    updated_topic = topic_model.get_topic(1)[:10]

    topic_model.update_topics(documents, vectorizer_model=vectorizer_model)
    original_topic = topic_model.get_topic(1)[:10]

    assert topic != updated_topic
    # Only models that still carry a representation model are expected to differ
    # from the pre-update words after restoring the original vectorizer.
    if topic_model.representation_model is not None:
        assert topic != original_topic

    # Test updating topic labels
    topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ")
    assert len(topic_labels) == len(set(topic_model.topics_))

    # Test setting topic labels
    topic_model.set_topic_labels(topic_labels)
    assert topic_model.custom_labels_ == topic_labels

    # Test merging topics
    freq = topic_model.get_topic_freq(0)
    topics_to_merge = [0, 1]
    topic_model.merge_topics(documents, topics_to_merge)
    assert freq < topic_model.get_topic_freq(0)

    # Test reduction of outliers. Use the model's *current* assignments: the
    # `topics` captured at the start predates the reduce/merge steps above.
    current_topics = topic_model.topics_
    if -1 in current_topics:
        new_topics = topic_model.reduce_outliers(documents, current_topics, threshold=0.0)
        nr_outliers_topic_model = sum(1 for assigned in current_topics if assigned == -1)
        nr_outliers_new_topics = sum(1 for assigned in new_topics if assigned == -1)

        # NOTE(review): `_outliers` appears to flag whether the model has an outlier topic — confirm
        if topic_model._outliers == 1:
            assert nr_outliers_topic_model > nr_outliers_new_topics

    # Combine models: merge the current model with the freshly saved base model
    topic_model1 = BERTopic.load(model_dir)
    merged_model = BERTopic.merge_models([topic_model, topic_model1])

    assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
    assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())