nlomov's picture
Added all source code
ab33b80
Raw
History Blame Contribute Delete
8.08 kB
import numpy as np
from sentence_transformers import SentenceTransformer
import scipy.sparse
import warnings
from contextualized_topic_models.datasets.dataset import CTMDataset
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import OneHotEncoder
def get_bag_of_words(data, min_length):
    """
    Creates the bag of words

    Each element of ``data`` is indexed with a boolean mask that drops entries
    equal to None; the remaining values are cast to int and counted with
    ``np.bincount``. Documents whose remaining values sum to zero (which
    includes documents with nothing left after masking) produce no row.

    :param data: iterable of numpy arrays holding token indices (None entries are ignored)
    :param min_length: minimum number of columns of each count vector (``minlength`` of np.bincount)
    :return: scipy.sparse.csr_matrix with one row per retained document
    """
    missing = np.array(None)
    rows = []
    for doc in data:
        # keep only the real token ids of this document
        token_ids = doc[doc != missing]
        if np.sum(token_ids) == 0:
            # same filter as always: a zero sum means the document is skipped
            continue
        rows.append(np.bincount(token_ids.astype("int"), minlength=min_length))
    return scipy.sparse.csr_matrix(rows)
def bert_embeddings_from_file(
    text_file, sbert_model_to_load, batch_size=64, max_seq_length=None
):
    """
    Creates SBERT Embeddings from an input file, assumes one document per line

    :param text_file: path of a UTF-8 encoded text file; every line is one document
    :param sbert_model_to_load: value handed to the SentenceTransformer constructor
    :param batch_size: batch size forwarded to ``model.encode``
    :param max_seq_length: optional value assigned to ``model.max_seq_length``; when None the
        model keeps its own setting and no length check is performed
    :return: np.ndarray built from the output of ``model.encode``
    """
    model = SentenceTransformer(sbert_model_to_load)

    if max_seq_length is not None:
        model.max_seq_length = max_seq_length

    with open(text_file, encoding="utf-8") as filino:
        # lines are kept exactly as read (including the trailing newline)
        texts = filino.readlines()

    # The length check compares word counts with max_seq_length, so it only
    # makes sense when an explicit limit was given (None would raise a
    # TypeError) and when there is at least one document to measure.
    if max_seq_length is not None and texts:
        check_max_local_length(max_seq_length, texts)

    return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size))
def bert_embeddings_from_list(
    texts, sbert_model_to_load, batch_size=64, max_seq_length=None
):
    """
    Creates SBERT Embeddings from a list

    :param texts: list of documents (strings) to embed
    :param sbert_model_to_load: value handed to the SentenceTransformer constructor
    :param batch_size: batch size forwarded to ``model.encode``
    :param max_seq_length: optional value assigned to ``model.max_seq_length``; when None the
        model keeps its own setting and no length check is performed
    :return: np.ndarray built from the output of ``model.encode``
    """
    model = SentenceTransformer(sbert_model_to_load)

    if max_seq_length is not None:
        model.max_seq_length = max_seq_length

    # Skip the length check when no explicit limit exists (comparing with None
    # would raise a TypeError) or when there is nothing to measure.
    if max_seq_length is not None and len(texts) > 0:
        check_max_local_length(max_seq_length, texts)

    return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size))
def check_max_local_length(max_seq_length, texts):
    """
    Warns when the longest document has more whitespace-separated words than
    ``max_seq_length``.

    :param max_seq_length: limit to compare against; None disables the check
    :param texts: iterable of strings; an empty collection never triggers the warning
    :return: None
    """
    if max_seq_length is None:
        # no explicit limit was requested, so there is nothing to compare with
        return

    # default=0 keeps an empty collection from raising
    max_local_length = max((len(t.split()) for t in texts), default=0)

    if max_local_length > max_seq_length:
        # NOTE: this filter only concerns DeprecationWarning, while the warning
        # emitted below uses the default category of warnings.warn; the call is
        # kept so the function's effect on the global filters is unchanged.
        warnings.simplefilter("always", DeprecationWarning)
        warnings.warn(
            f"the longest document in your collection has {max_local_length} words, the model instead "
            f"truncates to {max_seq_length} tokens."
        )
class TopicModelDataPreparation:
    """
    Prepares CTMDataset objects by pairing contextualized document embeddings
    with a bag-of-words matrix produced by a CountVectorizer.
    """

    def __init__(
        self, contextualized_model=None, show_warning=True, max_seq_length=128
    ):
        """
        :param contextualized_model: value passed as ``sbert_model_to_load`` when embeddings have to
            be computed; may stay None if custom embeddings are always supplied
        :param show_warning: if True, ``transform`` warns when ``text_for_bow`` is missing
        :param max_seq_length: forwarded to ``bert_embeddings_from_list``
        """
        self.contextualized_model = contextualized_model
        self.vocab = []
        self.id2token = {}
        self.vectorizer = None  # set by fit
        self.label_encoder = None  # set by fit when labels are given
        self.show_warning = show_warning
        self.max_seq_length = max_seq_length

    @staticmethod
    def _has_labels(labels):
        """Returns True when labels were supplied and contain at least one element."""
        # an explicit length test also works for numpy arrays, whose truth
        # value is ambiguous
        return labels is not None and len(labels) > 0

    def load(self, contextualized_embeddings, bow_embeddings, id2token, labels=None):
        """
        Wraps already computed inputs into a CTMDataset without any processing.

        :param contextualized_embeddings: passed as ``X_contextual``
        :param bow_embeddings: passed as ``X_bow``
        :param id2token: passed as ``idx2token``
        :param labels: passed as ``labels`` (optional)
        """
        return CTMDataset(
            X_contextual=contextualized_embeddings,
            X_bow=bow_embeddings,
            idx2token=id2token,
            labels=labels,
        )

    def fit(
        self, text_for_contextual, text_for_bow, labels=None, custom_embeddings=None
    ):
        """
        This method fits the vectorizer and gets the embeddings from the contextual model

        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param custom_embeddings: np.ndarray type object to use custom embeddings (optional).
        :param labels: list of labels associated with each document (optional).
        :return: CTMDataset holding the embeddings, the bag-of-words and the encoded labels
        :raises TypeError: if custom_embeddings is given but is not a numpy.ndarray
        :raises Exception: if neither a contextualized model nor custom embeddings are available
        """
        if custom_embeddings is not None:
            assert len(text_for_contextual) == len(custom_embeddings)

            if text_for_bow is not None:
                assert len(custom_embeddings) == len(text_for_bow)

            if not isinstance(custom_embeddings, np.ndarray):
                raise TypeError(
                    "contextualized_embeddings must be a numpy.ndarray type object"
                )

        if text_for_bow is not None:
            assert len(text_for_contextual) == len(text_for_bow)

        if self.contextualized_model is None and custom_embeddings is None:
            raise Exception(
                "A contextualized model or contextualized embeddings must be defined"
            )

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer()

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)

        # if the user is passing custom embeddings we don't need to create the embeddings using the model
        if custom_embeddings is None:
            train_contextualized_embeddings = bert_embeddings_from_list(
                text_for_contextual,
                sbert_model_to_load=self.contextualized_model,
                max_seq_length=self.max_seq_length,
            )
        else:
            train_contextualized_embeddings = custom_embeddings

        self.vocab = self.vectorizer.get_feature_names_out()
        self.id2token = dict(enumerate(self.vocab))

        if self._has_labels(labels):
            self.label_encoder = OneHotEncoder()
            encoded_labels = self.label_encoder.fit_transform(
                np.array([labels]).reshape(-1, 1)
            )
        else:
            encoded_labels = None

        return CTMDataset(
            X_contextual=train_contextualized_embeddings,
            X_bow=train_bow_embeddings,
            idx2token=self.id2token,
            labels=encoded_labels,
        )

    def transform(
        self,
        text_for_contextual,
        text_for_bow=None,
        custom_embeddings=None,
        labels=None,
    ):
        """
        This method create the input for the prediction. Essentially, it creates the embeddings with the contextualized
        model of choice and with trained vectorizer.

        If text_for_bow is missing, it should be because we are using ZeroShotTM

        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param custom_embeddings: np.ndarray type object to use custom embeddings (optional).
        :param labels: list of labels associated with each document (optional).
        :return: CTMDataset holding the embeddings, the bag-of-words (or a dummy matrix) and the encoded labels
        :raises Exception: if embeddings must be computed but no contextualized model is defined
        """
        if custom_embeddings is not None:
            assert len(text_for_contextual) == len(custom_embeddings)

            if text_for_bow is not None:
                assert len(custom_embeddings) == len(text_for_bow)

        if text_for_bow is not None:
            assert len(text_for_contextual) == len(text_for_bow)

        # a model is only required when the embeddings have to be computed here;
        # with custom embeddings the same rule as in fit applies
        if self.contextualized_model is None and custom_embeddings is None:
            raise Exception(
                "You should define a contextualized model if you want to create the embeddings"
            )

        if text_for_bow is not None:
            test_bow_embeddings = self.vectorizer.transform(text_for_bow)
        else:
            # dummy matrix
            if self.show_warning:
                # kept as in the original; this filter concerns DeprecationWarning only
                warnings.simplefilter("always", DeprecationWarning)
                warnings.warn(
                    "The method did not have in input the text_for_bow parameter. This IS EXPECTED if you "
                    "are using ZeroShotTM in a cross-lingual setting"
                )

            # we just need an object that is matrix-like so that pytorch does not complain
            test_bow_embeddings = scipy.sparse.csr_matrix(
                np.zeros((len(text_for_contextual), 1))
            )

        if custom_embeddings is None:
            test_contextualized_embeddings = bert_embeddings_from_list(
                text_for_contextual,
                sbert_model_to_load=self.contextualized_model,
                max_seq_length=self.max_seq_length,
            )
        else:
            test_contextualized_embeddings = custom_embeddings

        if self._has_labels(labels):
            encoded_labels = self.label_encoder.transform(
                np.array([labels]).reshape(-1, 1)
            )
        else:
            encoded_labels = None

        return CTMDataset(
            X_contextual=test_contextualized_embeddings,
            X_bow=test_bow_embeddings,
            idx2token=self.id2token,
            labels=encoded_labels,
        )