# NOTE: the original extraction carried web-UI status residue here
# ("Spaces: Running") — it is not part of the source file.
# Third-party
import numpy as np
from bertopic import BERTopic

# Project-local
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger

# Module-level logger; DBERTopicTrainer raises it to DEBUG when verbose=True.
logger = Logger("WARNING")
class DBERTopicTrainer:
    """Dynamic-topic-model trainer built on BERTopic's ``topics_over_time``.

    Fits a BERTopic model on the dataset's raw documents, then derives a
    time-indexed topic-word matrix ("beta", shape ``(T, K, V)``) from the
    per-timestamp topic representations in the topics-over-time table.
    """

    def __init__(self,
                 dataset,
                 num_topics=20,
                 num_top_words=15,
                 nr_bins=20,
                 global_tuning=True,
                 evolution_tuning=True,
                 datetime_format=None,
                 verbose=False):
        """
        Args:
            dataset: object exposing ``raw_documents`` and ``vocab``
                attributes (project dataset type — TODO confirm exact schema).
            num_topics: target number of topics (forwarded to BERTopic as
                ``nr_topics``).
            num_top_words: default number of top words per topic for
                ``get_top_words``.
            nr_bins, global_tuning, evolution_tuning, datetime_format:
                accepted for interface compatibility but currently UNUSED —
                ``train`` derives its own bin count and takes its own
                ``datetime_format`` argument.
            verbose: when True, raise the module logger to DEBUG and run
                BERTopic verbosely.
        """
        self.dataset = dataset
        self.docs = dataset.raw_documents
        self.num_topics = num_topics
        # self.timestamps = dataset.train_times
        # NOTE: overwritten by train() with the fitted vectorizer's vocabulary.
        self.vocab = dataset.vocab
        self.num_top_words = num_top_words
        # The following options are accepted but not forwarded anywhere yet:
        # self.nr_bins = nr_bins
        # self.global_tuning = global_tuning
        # self.evolution_tuning = evolution_tuning
        # self.datetime_format = datetime_format
        self.verbose = verbose
        logger.set_level("DEBUG" if verbose else "WARNING")

    def train(self, timestamps, datetime_format='%Y'):
        """Fit BERTopic on ``self.docs`` and compute topics over time.

        Args:
            timestamps: per-document timestamps, parallel to ``self.docs``.
            datetime_format: strptime format for parsing ``timestamps``
                (passed through to BERTopic).

        Side effects:
            Sets ``self.model``, ``self.topics``, ``self.topics_over_time_df``,
            ``self.unique_timestamps``, ``self.unique_topics``, and the
            dimensions ``self.V`` (vocab), ``self.K`` (topics), ``self.T``
            (time slices). Overwrites ``self.vocab`` with the fitted
            vectorizer's feature names.
        """
        logger.info("Fitting BERTopic...")
        self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
        self.topics, _ = self.model.fit_transform(self.docs)
        logger.info("Running topics_over_time...")
        # One bin per distinct timestamp; the nr_bins given to __init__ is
        # deliberately not used here.
        self.topics_over_time_df = self.model.topics_over_time(
            docs=self.docs,
            timestamps=timestamps,
            nr_bins=len(set(timestamps)),
            datetime_format=datetime_format,
        )
        self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
        self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
        self.vocab = self.model.vectorizer_model.get_feature_names_out()
        self.V = len(self.vocab)
        self.K = len(self.unique_topics)
        self.T = len(self.unique_timestamps)

    def get_beta(self):
        """Return a ``(T, K, V)`` array of per-time topic-word distributions.

        Each topic's representative words (the comma-separated "Words" column
        of ``topics_over_time_df``) contribute one count per appearance; every
        ``beta[t, k]`` row is then normalized to a probability distribution
        (a small epsilon keeps empty rows from dividing by zero).

        Requires ``train`` to have been called first.
        """
        logger.info("Generating β matrix...")
        beta = np.zeros((self.T, self.K, self.V))
        topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}
        # Build the word->column map once: O(1) lookups instead of an O(V)
        # membership test plus an O(V) np.where scan for every single word.
        word_to_index = {word: idx for idx, word in enumerate(self.vocab)}
        # Extract topic representations at each time slice.
        for t_idx, timestamp in enumerate(self.unique_timestamps):
            selection = self.topics_over_time_df[self.topics_over_time_df["Timestamp"] == timestamp]
            for _, row in selection.iterrows():
                topic = row["Topic"]
                if topic not in topic_to_index:
                    continue
                k = topic_to_index[topic]
                for word in row["Words"].split(", "):
                    v = word_to_index.get(word)
                    # Words outside the fitted vocabulary are skipped.
                    if v is not None:
                        beta[t_idx, k, v] += 1.0
        # Normalize each beta[t, k] to be a probability distribution.
        beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
        return beta

    def get_top_words(self, num_top_words=None):
        """Return a list (one entry per time slice) of top words per topic.

        Args:
            num_top_words: how many words per topic; defaults to the value
                given at construction.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        top_words_list = []
        for time in range(beta.shape[0]):
            top_words = _utils.get_top_words(beta[time], self.vocab, num_top_words, self.verbose)
            top_words_list.append(top_words)
        return top_words_list

    def get_theta(self):
        """Document-topic distributions are not exposed by this wrapper."""
        # Not applicable for BERTopic; could return soft topic distributions
        # if required by downstream code.
        logger.warning("get_theta is not implemented for BERTopic.")
        return None

    def export_theta(self):
        """Document-topic export is not supported; returns ``(None, None)``."""
        logger.warning("export_theta is not implemented for BERTopic.")
        return None, None