File size: 3,804 Bytes
11c72a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
from bertopic import BERTopic
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger

# Module-level logger shared by this file. It starts at WARNING;
# DBERTopicTrainer.__init__ calls set_level() on it according to `verbose`.
logger = Logger("WARNING")


class DBERTopicTrainer:
    """Fit BERTopic on a dataset and expose time-sliced topic-word matrices.

    ``train`` fits a ``BERTopic`` model on the dataset's raw documents and
    runs ``topics_over_time`` on it. ``get_beta`` and ``get_top_words`` then
    convert the resulting table into a ``(T, K, V)`` array and per-slice top
    words, where ``T``, ``K`` and ``V`` are the numbers of distinct
    timestamps, topics and vocabulary entries recorded by ``train``.
    """

    def __init__(self,
                 dataset,
                 num_topics=20,
                 num_top_words=15,
                 nr_bins=20,
                 global_tuning=True,
                 evolution_tuning=True,
                 datetime_format=None,
                 verbose=False):
        """Store the dataset handles and configure the module logger.

        Args:
            dataset: Object exposing ``raw_documents`` and ``vocab``.
            num_topics: Passed to ``BERTopic`` as ``nr_topics`` in ``train``.
            num_top_words: Default number of words per topic returned by
                ``get_top_words``.
            nr_bins: Accepted for interface compatibility; currently unused
                (``train`` derives the bin count from its timestamps).
            global_tuning: Accepted for interface compatibility; currently
                unused.
            evolution_tuning: Accepted for interface compatibility; currently
                unused.
            datetime_format: Accepted for interface compatibility; currently
                unused (``train`` takes its own ``datetime_format``).
            verbose: When true, the module logger is set to ``DEBUG`` and
                BERTopic is created in verbose mode; otherwise ``WARNING``.
        """
        self.dataset = dataset
        self.docs = dataset.raw_documents
        self.num_topics = num_topics
        # Replaced by the fitted vectorizer's vocabulary once train() has run.
        self.vocab = dataset.vocab
        self.num_top_words = num_top_words
        self.verbose = verbose

        logger.set_level("DEBUG" if verbose else "WARNING")

    def train(self, timestamps, datetime_format='%Y'):
        """Fit BERTopic on the documents and compute topics over time.

        Args:
            timestamps: One timestamp per document, forwarded to
                ``BERTopic.topics_over_time``. The number of distinct values
                is used as ``nr_bins``.
            datetime_format: Forwarded to ``BERTopic.topics_over_time``.

        Side effects:
            Sets ``model``, ``topics``, ``topics_over_time_df``,
            ``unique_timestamps``, ``unique_topics``, ``vocab``, ``V``, ``K``
            and ``T`` on the instance.
        """
        logger.info("Fitting BERTopic...")
        self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
        self.topics, _ = self.model.fit_transform(self.docs)

        logger.info("Running topics_over_time...")
        self.topics_over_time_df = self.model.topics_over_time(
            docs=self.docs,
            timestamps=timestamps,
            nr_bins=len(set(timestamps)),
            datetime_format=datetime_format
        )

        self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
        self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
        self.vocab = self.model.vectorizer_model.get_feature_names_out()
        self.V = len(self.vocab)
        self.K = len(self.unique_topics)
        self.T = len(self.unique_timestamps)

    def get_beta(self):
        """Build the normalized ``(T, K, V)`` topic-word array.

        For every row of ``topics_over_time_df`` the ``Words`` string is split
        on ``", "`` and each word that occurs in ``self.vocab`` adds 1.0 to
        the cell for that row's timestamp and topic. Each ``beta[t, k]`` is
        then divided by its sum plus ``1e-10``, so rows without any matched
        word stay all-zero instead of producing NaN.

        Returns:
            numpy.ndarray of shape ``(self.T, self.K, self.V)``.
        """
        logger.info("Generating β matrix...")

        beta = np.zeros((self.T, self.K, self.V))
        topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}

        # One dict lookup per word instead of two linear scans of the
        # vocabulary; setdefault keeps the first index if a word repeats.
        word_to_index = {}
        for idx, word in enumerate(self.vocab):
            word_to_index.setdefault(word, idx)

        frame = self.topics_over_time_df
        # Rows are selected with an equality mask per timestamp (rather than a
        # dict keyed by timestamp) so pandas handles datetime comparisons.
        for t_idx, timestamp in enumerate(self.unique_timestamps):
            selection = frame[frame["Timestamp"] == timestamp]
            for topic, words in zip(selection["Topic"], selection["Words"]):
                k = topic_to_index.get(topic)
                if k is None:
                    continue
                for word in words.split(", "):
                    v = word_to_index.get(word)
                    if v is not None:
                        beta[t_idx, k, v] += 1.0

        # Normalize each β_tk to be a probability distribution
        beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
        return beta

    def get_top_words(self, num_top_words=None):
        """Return the top words of every topic for each time slice.

        Args:
            num_top_words: Words per topic; defaults to ``self.num_top_words``
                when ``None``.

        Returns:
            A list with one entry per time slice, each being whatever
            ``_utils.get_top_words`` returns for that slice's ``(K, V)``
            matrix.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        return [
            _utils.get_top_words(beta_t, self.vocab, num_top_words, self.verbose)
            for beta_t in beta
        ]

    def get_theta(self):
        """Log a warning and return ``None``; theta is not implemented here."""
        # Not applicable for BERTopic; can return topic assignments or soft topic distributions if required
        logger.warning("get_theta is not implemented for BERTopic.")
        return None

    def export_theta(self):
        """Log a warning and return ``(None, None)``; not implemented here."""
        logger.warning("export_theta is not implemented for BERTopic.")
        return None, None