import numpy as np
from sentence_transformers import SentenceTransformer
import scipy.sparse
import warnings
from contextualized_topic_models.datasets.dataset import CTMDataset
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import OneHotEncoder


def get_bag_of_words(data, min_length):
    """
    Creates the bag of words.

    :param data: iterable of numpy arrays of token ids; entries equal to
        ``None`` are ignored (presumably padding — confirm with callers).
    :param min_length: passed as ``minlength`` to ``np.bincount`` so every
        row has at least this many columns.
    :return: ``scipy.sparse.csr_matrix`` with one row per kept document.
        Documents whose non-``None`` token ids sum to 0 are skipped.
    """
    rows = []
    for x in data:
        # Compute the non-None entries once and reuse them for both the
        # filter and the count.
        tokens = x[x != np.array(None)]
        if np.sum(tokens) != 0:
            rows.append(np.bincount(tokens.astype("int"), minlength=min_length))
    return scipy.sparse.csr_matrix(rows)


def bert_embeddings_from_file(
    text_file, sbert_model_to_load, batch_size=64, max_seq_length=None
):
    """
    Creates SBERT Embeddings from an input file, assumes one document per line.

    :param text_file: path of a UTF-8 text file, one document per line.
    :param sbert_model_to_load: model name/path passed to ``SentenceTransformer``.
    :param batch_size: batch size forwarded to ``model.encode``.
    :param max_seq_length: if not None, overrides ``model.max_seq_length``
        and is used to warn about documents that may be truncated.
    :return: ``np.ndarray`` of the embeddings returned by ``model.encode``.
    """
    model = SentenceTransformer(sbert_model_to_load)

    if max_seq_length is not None:
        model.max_seq_length = max_seq_length

    with open(text_file, encoding="utf-8") as filino:
        # Lines keep their trailing newline characters, as read from the file.
        texts = filino.readlines()

    check_max_local_length(max_seq_length, texts)

    return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size))


def bert_embeddings_from_list(
    texts, sbert_model_to_load, batch_size=64, max_seq_length=None
):
    """
    Creates SBERT Embeddings from a list.

    :param texts: list of documents (strings).
    :param sbert_model_to_load: model name/path passed to ``SentenceTransformer``.
    :param batch_size: batch size forwarded to ``model.encode``.
    :param max_seq_length: if not None, overrides ``model.max_seq_length``
        and is used to warn about documents that may be truncated.
    :return: ``np.ndarray`` of the embeddings returned by ``model.encode``.
    """
    model = SentenceTransformer(sbert_model_to_load)

    if max_seq_length is not None:
        model.max_seq_length = max_seq_length

    check_max_local_length(max_seq_length, texts)

    return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size))


def check_max_local_length(max_seq_length, texts):
    """
    Warns when the longest document has more whitespace-separated words
    than ``max_seq_length``.

    Word counts are only a rough proxy for model tokens, so this is a hint
    rather than an exact truncation check.

    :param max_seq_length: limit to compare against; ``None`` means no
        explicit limit was requested and the check is skipped.
    :param texts: iterable of document strings.
    """
    # Without a limit there is nothing to compare against (``int > None``
    # would raise TypeError).
    if max_seq_length is None:
        return

    word_counts = [len(t.split()) for t in texts]
    # An empty collection has no longest document; np.max([]) would raise.
    if not word_counts:
        return

    max_local_length = max(word_counts)
    if max_local_length > max_seq_length:
        warnings.simplefilter("always", DeprecationWarning)
        warnings.warn(
            f"the longest document in your collection has {max_local_length} words, the model instead "
            f"truncates to {max_seq_length} tokens."
        )


def _has_labels(labels):
    """Return True when ``labels`` is given and non-empty (lists or arrays)."""
    # An explicit length check avoids the ambiguous truth value of numpy
    # arrays / pandas Series that ``if labels:`` would raise on.
    return labels is not None and len(labels) > 0


class TopicModelDataPreparation:
    """
    Builds ``CTMDataset`` objects from raw documents.

    The bag-of-words comes from a ``CountVectorizer`` fitted in ``fit``.
    The contextualized embeddings come either from a SentenceTransformer
    model (``contextualized_model``) or from user-supplied custom embeddings.
    """

    def __init__(
        self, contextualized_model=None, show_warning=True, max_seq_length=128
    ):
        # Model name/path for SentenceTransformer; may stay None when the
        # caller always supplies custom embeddings.
        self.contextualized_model = contextualized_model
        self.vocab = []
        self.id2token = {}
        # Set by ``fit``; ``transform`` reuses the fitted vectorizer/encoder.
        self.vectorizer = None
        self.label_encoder = None
        # Controls the ZeroShotTM warning emitted by ``transform``.
        self.show_warning = show_warning
        self.max_seq_length = max_seq_length

    def load(self, contextualized_embeddings, bow_embeddings, id2token, labels=None):
        """
        Wraps precomputed inputs into a ``CTMDataset`` without fitting anything.

        :param contextualized_embeddings: contextual embedding matrix.
        :param bow_embeddings: bag-of-words matrix.
        :param id2token: mapping from column index to token.
        :param labels: optional encoded labels.
        :return: ``CTMDataset``.
        """
        return CTMDataset(
            X_contextual=contextualized_embeddings,
            X_bow=bow_embeddings,
            idx2token=id2token,
            labels=labels,
        )

    def fit(
        self, text_for_contextual, text_for_bow, labels=None, custom_embeddings=None
    ):
        """
        This method fits the vectorizer and gets the embeddings from the contextual model.

        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param custom_embeddings: np.ndarray type object to use custom embeddings (optional).
        :param labels: list of labels associated with each document (optional).
        :return: ``CTMDataset`` with the training bag-of-words, embeddings,
            id2token mapping and one-hot encoded labels (or None).
        :raises AssertionError: if the inputs have mismatched lengths.
        :raises TypeError: if ``custom_embeddings`` is not a numpy array.
        :raises Exception: if neither a model nor custom embeddings are available.
        """
        if custom_embeddings is not None:
            assert len(text_for_contextual) == len(custom_embeddings)

            if text_for_bow is not None:
                assert len(custom_embeddings) == len(text_for_bow)

            if not isinstance(custom_embeddings, np.ndarray):
                raise TypeError(
                    "contextualized_embeddings must be a numpy.ndarray type object"
                )

        if text_for_bow is not None:
            assert len(text_for_contextual) == len(text_for_bow)

        if self.contextualized_model is None and custom_embeddings is None:
            raise Exception(
                "A contextualized model or contextualized embeddings must be defined"
            )

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer()

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)

        # if the user is passing custom embeddings we don't need to create the embeddings using the model
        if custom_embeddings is None:
            train_contextualized_embeddings = bert_embeddings_from_list(
                text_for_contextual,
                sbert_model_to_load=self.contextualized_model,
                max_seq_length=self.max_seq_length,
            )
        else:
            train_contextualized_embeddings = custom_embeddings

        self.vocab = self.vectorizer.get_feature_names_out()
        self.id2token = dict(enumerate(self.vocab))

        if _has_labels(labels):
            self.label_encoder = OneHotEncoder()
            encoded_labels = self.label_encoder.fit_transform(
                np.array([labels]).reshape(-1, 1)
            )
        else:
            encoded_labels = None

        return CTMDataset(
            X_contextual=train_contextualized_embeddings,
            X_bow=train_bow_embeddings,
            idx2token=self.id2token,
            labels=encoded_labels,
        )

    def transform(
        self,
        text_for_contextual,
        text_for_bow=None,
        custom_embeddings=None,
        labels=None,
    ):
        """
        This method creates the input for the prediction. Essentially, it creates the embeddings with the contextualized
        model of choice and with the trained vectorizer.

        If text_for_bow is missing, it should be because we are using ZeroShotTM.

        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param custom_embeddings: np.ndarray type object to use custom embeddings (optional).
        :param labels: list of labels associated with each document (optional).
        :return: ``CTMDataset`` for prediction.
        :raises AssertionError: if the inputs have mismatched lengths.
        :raises Exception: if embeddings must be computed but no model is defined.
        """
        if custom_embeddings is not None:
            assert len(text_for_contextual) == len(custom_embeddings)

            if text_for_bow is not None:
                assert len(custom_embeddings) == len(text_for_bow)

        if text_for_bow is not None:
            assert len(text_for_contextual) == len(text_for_bow)

        # A model is only required when we have to compute the embeddings
        # ourselves; custom embeddings make it unnecessary (same rule as fit).
        if self.contextualized_model is None and custom_embeddings is None:
            raise Exception(
                "You should define a contextualized model if you want to create the embeddings"
            )

        if text_for_bow is not None:
            test_bow_embeddings = self.vectorizer.transform(text_for_bow)
        else:
            # dummy matrix
            if self.show_warning:
                warnings.simplefilter("always", DeprecationWarning)
                warnings.warn(
                    "The method did not have in input the text_for_bow parameter. This IS EXPECTED if you "
                    "are using ZeroShotTM in a cross-lingual setting"
                )

            # we just need an object that is matrix-like so that pytorch does not complain
            test_bow_embeddings = scipy.sparse.csr_matrix(
                np.zeros((len(text_for_contextual), 1))
            )

        if custom_embeddings is None:
            test_contextualized_embeddings = bert_embeddings_from_list(
                text_for_contextual,
                sbert_model_to_load=self.contextualized_model,
                max_seq_length=self.max_seq_length,
            )
        else:
            test_contextualized_embeddings = custom_embeddings

        if _has_labels(labels):
            # Reuses the encoder fitted in ``fit``; assumes fit received labels.
            encoded_labels = self.label_encoder.transform(
                np.array([labels]).reshape(-1, 1)
            )
        else:
            encoded_labels = None

        return CTMDataset(
            X_contextual=test_contextualized_embeddings,
            X_bow=test_bow_embeddings,
            idx2token=self.id2token,
            labels=encoded_labels,
        )