# File: Bertopic/src/scripts/topic_modeling.py
# Dopler47's picture
# Last commit message: "Error fix"
# Last commit hash: a6dee29
import os
import spaces
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP
from src.utils.constants import EMBEDDING_MODEL_NAME, MODEL_REPO_ID
from src.utils.utils import get_timestamp
# Hugging Face access token read from the environment; stays None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Sentence-embedding model, instantiated once when this module is imported.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
@spaces.GPU(duration=120)
def topic_modeling(
    filename,
    docs,
    embeddings,
    n_gram_range=(3, 6),
    mmr_diversity=1,
    mmr_top_n_words=30,
    keybert_top_n_words=50,
    random_state=42,
    min_cluster_size=15,
):
    """
    Fit a BERTopic model on pre-embedded documents and push it to the Hub.

    The fitted model is uploaded to the ``MODEL_REPO_ID`` repository (as a
    private repo, authenticated with ``HF_TOKEN``) before the topic table is
    returned. The module-level ``embedding_model`` is handed to BERTopic
    alongside the precomputed ``embeddings``.

    Parameters
    ----------
    filename : str
        Label of the current run. It is only interpolated into the commit
        message of the Hub upload.
    docs : List of str
        The list of documents to be topic modeled.
    embeddings : numpy.ndarray, or list/tuple of numpy.ndarray
        Precomputed embeddings of the given documents. A list or tuple is
        stacked into a single array before fitting, because BERTopic only
        accepts NumPy arrays or SciPy sparse matrices here.
    n_gram_range : Tuple of int, optional
        The range of n-grams to be considered. Defaults to (3, 6).
    mmr_diversity : float, optional
        The diversity value of the MMR model. Defaults to 1.
    mmr_top_n_words : int, optional
        The number of top words to be considered in the MMR model.
        Defaults to 30.
    keybert_top_n_words : int, optional
        The number of top words to be considered in the KeyBERT model.
        Defaults to 50.
    random_state : int, optional
        The random seed passed to the KeyBERT-inspired representation and to
        UMAP for reproducibility. Defaults to 42.
    min_cluster_size : int, optional
        The minimum size of a cluster to be considered as a topic.
        Defaults to 15.

    Returns
    -------
    topic_info_df : pandas.DataFrame
        The topic information dataframe, as returned by
        ``BERTopic.get_topic_info()``.
    """
    # Local import: numpy is not imported at module level in this file.
    import numpy as np

    # BERTopic validates embeddings as an ndarray / sparse matrix; a plain
    # sequence of per-document vectors would be rejected, so stack it first.
    # Arrays and sparse matrices are passed through untouched.
    if isinstance(embeddings, (list, tuple)):
        embeddings = np.asarray(embeddings)

    # Representation models are given as a list, i.e. chained in this order:
    # KeyBERT-inspired keywords first, then MMR on top of them.
    representation_model = [
        KeyBERTInspired(top_n_words=keybert_top_n_words, random_state=random_state),
        MaximalMarginalRelevance(diversity=mmr_diversity, top_n_words=mmr_top_n_words),
    ]

    # Clustering step.
    hdbscan_model = HDBSCAN(
        min_cluster_size=min_cluster_size,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )

    # Dimensionality-reduction step; seeded so repeated runs use the same
    # random state.
    umap_model = UMAP(
        n_neighbors=15,
        n_components=5,
        min_dist=0.0,
        metric="cosine",
        low_memory=False,
        random_state=random_state,
    )

    topic_model = BERTopic(
        embedding_model=embedding_model,
        representation_model=representation_model,
        n_gram_range=n_gram_range,
        hdbscan_model=hdbscan_model,
        umap_model=umap_model,
        verbose=True,
    ).fit(docs, embeddings=embeddings)

    # Publish the fitted model; the commit message records when and for which
    # input the model was produced.
    topic_model.push_to_hf_hub(
        repo_id=MODEL_REPO_ID,
        commit_message=f"{get_timestamp()} - {filename}",
        token=HF_TOKEN,
        private=True,
        serialization="safetensors",
        save_embedding_model=EMBEDDING_MODEL_NAME,
        save_ctfidf=True,
    )

    topic_info_df = topic_model.get_topic_info()
    return topic_info_df