import os

import spaces
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP

from src.utils.constants import EMBEDDING_MODEL_NAME, MODEL_REPO_ID
from src.utils.utils import get_timestamp

# Token for pushing the fitted model to the Hugging Face Hub; may be None,
# in which case push_to_hf_hub falls back to the ambient HF credentials.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Loaded once at import time so every call to topic_modeling reuses the
# same sentence-transformer instance.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)


@spaces.GPU(duration=120)
def topic_modeling(
    filename,
    docs,
    embeddings,
    n_gram_range=(3, 6),
    mmr_diversity=1,
    mmr_top_n_words=30,
    keybert_top_n_words=50,
    random_state=42,
    min_cluster_size=15,
):
    """
    Perform topic modeling on a list of documents and their embeddings.

    Fits a BERTopic model (UMAP + HDBSCAN, with KeyBERT-inspired and
    MMR representations) on precomputed embeddings, pushes the fitted
    model to the Hugging Face Hub as a side effect, and returns the
    topic-information table.

    Parameters
    ----------
    filename : str
        Name of the source file the documents came from; recorded in the
        Hub commit message for traceability.
    docs : List of str
        The list of documents to be topic modeled.
    embeddings : List of numpy.ndarray
        The list of embeddings of the given documents.
    n_gram_range : Tuple of int, optional
        The range of n-grams to be considered. Defaults to (3, 6).
    mmr_diversity : float, optional
        The diversity value of the MMR model. Defaults to 1.
    mmr_top_n_words : int, optional
        The number of top words to be considered in the MMR model.
        Defaults to 30.
    keybert_top_n_words : int, optional
        The number of top words to be considered in the KeyBERT model.
        Defaults to 50.
    random_state : int, optional
        The random seed for reproducibility. Defaults to 42.
    min_cluster_size : int, optional
        The minimum size of a cluster to be considered as a topic.
        Defaults to 15.

    Returns
    -------
    topic_info_df : pandas.DataFrame
        The topic information dataframe.
    """
    # Representations are applied in order: KeyBERT-inspired keywords,
    # then MMR re-ranking for diversity.
    representation_model = [
        KeyBERTInspired(top_n_words=keybert_top_n_words, random_state=random_state),
        MaximalMarginalRelevance(diversity=mmr_diversity, top_n_words=mmr_top_n_words),
    ]
    # prediction_data=True is required so BERTopic can assign topics to
    # new documents later (approximate prediction).
    hdbscan_model = HDBSCAN(
        min_cluster_size=min_cluster_size,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )
    # Fixed random_state makes the UMAP projection reproducible (at the
    # cost of disabling parallelism inside UMAP).
    umap_model = UMAP(
        n_neighbors=15,
        n_components=5,
        min_dist=0.0,
        metric="cosine",
        low_memory=False,
        random_state=random_state,
    )
    topic_model = BERTopic(
        embedding_model=embedding_model,
        representation_model=representation_model,
        n_gram_range=n_gram_range,
        hdbscan_model=hdbscan_model,
        umap_model=umap_model,
        verbose=True,
    ).fit(docs, embeddings=embeddings)

    # Fix: record the actual filename in the commit message — the
    # `filename` parameter was previously unused and the message carried
    # a "(unknown)" placeholder instead.
    topic_model.push_to_hf_hub(
        repo_id=MODEL_REPO_ID,
        commit_message=f"{get_timestamp()} - {filename}",
        token=HF_TOKEN,
        private=True,
        serialization="safetensors",
        save_embedding_model=EMBEDDING_MODEL_NAME,
        save_ctfidf=True,
    )

    topic_info_df = topic_model.get_topic_info()
    return topic_info_df