# DTECT/backend/models/DBERTopic_trainer.py
# (Hugging Face page residue from scrape: author AdhyaSuman,
#  commit "Initial commit with Git LFS for large files", hash 11c72a2)
import numpy as np
from bertopic import BERTopic
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger
logger = Logger("WARNING")
class DBERTopicTrainer:
    """Train a BERTopic model on time-stamped documents and expose
    dynamic-topic-model style outputs (a per-timestamp topic-word matrix).

    The trainer fits a single BERTopic model on all documents, then calls
    ``topics_over_time`` to obtain per-timestamp topic representations,
    from which a beta tensor of shape (T, K, V) is derived.
    """

    def __init__(self,
                 dataset,
                 num_topics=20,
                 num_top_words=15,
                 nr_bins=20,
                 global_tuning=True,
                 evolution_tuning=True,
                 datetime_format=None,
                 verbose=False):
        """
        Args:
            dataset: object exposing ``raw_documents`` (list of str) and
                ``vocab`` (iterable of str) — TODO confirm exact contract.
            num_topics: target number of topics (passed to BERTopic's
                ``nr_topics`` reduction).
            num_top_words: default number of top words per topic returned
                by :meth:`get_top_words`.
            nr_bins, global_tuning, evolution_tuning, datetime_format:
                accepted for interface compatibility but currently unused;
                ``train()`` derives its own binning and format.
            verbose: if True, raise the module logger to DEBUG and make
                BERTopic verbose.
        """
        self.dataset = dataset
        self.docs = dataset.raw_documents
        self.num_topics = num_topics
        # Initial vocab from the dataset; replaced after training by the
        # vocabulary actually used by BERTopic's vectorizer.
        self.vocab = dataset.vocab
        self.num_top_words = num_top_words
        self.verbose = verbose
        logger.set_level("DEBUG" if verbose else "WARNING")

    def train(self, timestamps, datetime_format='%Y'):
        """Fit BERTopic and compute topics over time.

        Args:
            timestamps: per-document timestamps, aligned with the docs.
            datetime_format: strptime format for parsing string timestamps
                (forwarded to ``topics_over_time``).

        Side effects: sets ``model``, ``topics``, ``topics_over_time_df``,
        ``unique_timestamps``, ``unique_topics``, ``vocab``, ``V``, ``K``,
        ``T``.
        """
        logger.info("Fitting BERTopic...")
        self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
        self.topics, _ = self.model.fit_transform(self.docs)

        logger.info("Running topics_over_time...")
        # One bin per distinct timestamp so the beta tensor has a slice
        # for every observed time point.
        self.topics_over_time_df = self.model.topics_over_time(
            docs=self.docs,
            timestamps=timestamps,
            nr_bins=len(set(timestamps)),
            datetime_format=datetime_format
        )

        self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
        self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
        # Use the vocabulary BERTopic actually vectorized with (a numpy
        # array of feature names), not the dataset's original vocab.
        self.vocab = self.model.vectorizer_model.get_feature_names_out()
        self.V = len(self.vocab)
        self.K = len(self.unique_topics)
        self.T = len(self.unique_timestamps)

    def get_beta(self):
        """Build the (T, K, V) topic-word tensor from ``topics_over_time_df``.

        Each topic's representative words at a timestamp receive equal
        count mass; each beta[t, k] is then normalized to a probability
        distribution (with a small epsilon to guard empty rows).
        """
        logger.info("Generating β matrix...")
        beta = np.zeros((self.T, self.K, self.V))
        topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}
        # O(1) word lookup instead of scanning the vocab array per word.
        word_to_index = {word: idx for idx, word in enumerate(self.vocab)}

        # Extract topic representations at each time.
        for t_idx, timestamp in enumerate(self.unique_timestamps):
            selection = self.topics_over_time_df[
                self.topics_over_time_df["Timestamp"] == timestamp
            ]
            for _, row in selection.iterrows():
                topic = row["Topic"]
                if topic not in topic_to_index:
                    continue
                k = topic_to_index[topic]
                # "Words" is a comma-separated string of top words.
                for word in row["Words"].split(", "):
                    v = word_to_index.get(word)
                    if v is not None:
                        beta[t_idx, k, v] += 1.0

        # Normalize each beta[t, k] to a probability distribution.
        beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
        return beta

    def get_top_words(self, num_top_words=None):
        """Return, for each timestamp, the top words per topic.

        Args:
            num_top_words: words per topic; defaults to the value set in
                ``__init__``.

        Returns:
            list (length T) of per-timestamp top-word structures as
            produced by ``_utils.get_top_words``.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        return [
            _utils.get_top_words(beta[t], self.vocab, num_top_words, self.verbose)
            for t in range(beta.shape[0])
        ]

    def get_theta(self):
        """Not applicable for BERTopic; always returns None."""
        logger.warning("get_theta is not implemented for BERTopic.")
        return None

    def export_theta(self):
        """Not applicable for BERTopic; always returns (None, None)."""
        logger.warning("export_theta is not implemented for BERTopic.")
        return None, None