|
|
from typing import Optional

import torch
from torch import nn
from torch import Tensor, LongTensor

from transformers import AutoTokenizer, AutoModel
|
|
|
|
|
|
|
|
class WordTransformerEncoder(nn.Module):
    """
    Encodes sentences into word-level embeddings using a pretrained MLM transformer.

    Sentences are given as pre-split lists of words. Each word's embedding is the
    mean of the final-layer hidden states of the subtokens the word was split into.
    """

    def __init__(self, model_name: str):
        """
        Load the tokenizer and the transformer weights for ``model_name``.

        Note: ``forward`` relies on ``BatchEncoding.word_ids``, which Hugging Face
        only provides for fast (Rust-backed) tokenizers.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, words: list[list[str]]) -> Tensor:
        """
        Build word embeddings.

        - Tokenizes input sentences (already split into words) into subtokens.
        - Passes the subtokens through the pretrained transformer model.
        - Aggregates subtoken embeddings into word embeddings using mean pooling.

        Args:
            words: A batch of sentences, each given as a list of words.

        Returns:
            A tensor of shape ``(batch_size, max_words, hidden_size)``, where
            ``max_words`` is the length of the longest input sentence. Entry
            ``[i, j]`` is the embedding of ``words[i][j]``. Positions past the end
            of a shorter sentence, and words that received no subtokens (e.g.
            because the tokenized sequence was truncated to the model's maximum
            length), are zero vectors.
        """
        batch_size = len(words)
        max_words = max((len(sentence) for sentence in words), default=0)

        if batch_size == 0:
            # The tokenizer / tensor construction below cannot handle an empty batch.
            return torch.zeros(
                (0, 0, self.get_embedding_size()),
                dtype=self.model.dtype,
                device=self.model.device,
            )

        # Sequences longer than the model limit are truncated (truncation=True).
        subtokens = self.tokenizer(
            words,
            padding=True,
            truncation=True,
            is_split_into_words=True,
            return_tensors='pt',
        ).to(self.model.device)

        # Map every subtoken to (index of its word + 1). Special and padding tokens
        # have word_id None and map to 0, a sink slot that is dropped after pooling.
        # padding=True makes all rows the same length, so one tensor call suffices.
        words_ids = torch.tensor(
            [
                [0 if word_id is None else word_id + 1 for word_id in subtokens.word_ids(batch_idx)]
                for batch_idx in range(batch_size)
            ],
            dtype=torch.long,
            device=self.model.device,
        )

        subtokens_embeddings = self.model(**subtokens).last_hidden_state

        # Size the output from the input sentences, not from the surviving
        # subtokens, so truncated trailing words still get a (zero) row.
        return self._aggregate_subtokens_embeddings(subtokens_embeddings, words_ids, n_words=max_words)

    def _aggregate_subtokens_embeddings(
        self,
        subtokens_embeddings: Tensor,
        words_ids: LongTensor,
        n_words: Optional[int] = None,
    ) -> Tensor:
        """
        Aggregate subtoken embeddings into word embeddings by averaging.

        Multiple subtokens corresponding to a single word are combined into a
        single embedding (their mean).

        Args:
            subtokens_embeddings: Hidden states of shape
                ``(batch_size, n_subtokens, hidden_size)``.
            words_ids: Indices of shape ``(batch_size, n_subtokens)``. The value is
                ``k + 1`` for a subtoken of word ``k`` and ``0`` for tokens that
                belong to no word.
            n_words: Number of word rows in the output. Defaults to the largest
                word index present in ``words_ids``.

        Returns:
            Word embeddings of shape ``(batch_size, n_words, hidden_size)``.
            Words that received no subtokens are zero vectors.

        Raises:
            ValueError: If ``n_words`` is smaller than a word index in ``words_ids``.
        """
        batch_size, n_subtokens, embedding_size = subtokens_embeddings.shape

        max_word_id = int(words_ids.max()) if words_ids.numel() else 0
        if n_words is None:
            n_words = max_word_id
        elif n_words < max_word_id:
            raise ValueError(
                f"n_words={n_words} is smaller than the largest word index {max_word_id}"
            )

        # One extra leading slot (index 0) absorbs special/padding tokens.
        words_embeddings = torch.zeros(
            size=(batch_size, n_words + 1, embedding_size),
            dtype=subtokens_embeddings.dtype,
            device=subtokens_embeddings.device,
        )

        words_ids_expanded = words_ids.unsqueeze(-1).expand(batch_size, n_subtokens, embedding_size)

        # include_self=False: the initial zeros do not take part in the mean, and
        # slots that receive no subtoken keep their zero value.
        words_embeddings.scatter_reduce_(
            dim=1,
            index=words_ids_expanded,
            src=subtokens_embeddings,
            reduce="mean",
            include_self=False,
        )

        # Drop the sink slot so that row j corresponds to word j.
        return words_embeddings[:, 1:, :]

    def get_embedding_size(self) -> int:
        """Returns the embedding size of the transformer model, e.g. 768 for BERT."""
        return self.model.config.hidden_size

    def get_embeddings_layer(self) -> nn.Module:
        """Returns the model's ``embeddings`` submodule."""
        return self.model.embeddings

    def get_transformer_layers(self) -> list[nn.Module]:
        """
        Return a flat list of all transformer-*block* layers, excluding embeddings/poolers, etc.

        The elements of every outermost ``nn.ModuleList`` in the model are returned
        in traversal order. ModuleLists nested inside an already-collected list
        (e.g. a per-block list of sub-layers) are skipped, so sub-layers are not
        reported alongside the blocks that contain them.
        """
        layers = []
        collected_prefixes = []
        for name, sub in self.model.named_modules():
            if not isinstance(sub, nn.ModuleList):
                continue
            # named_modules() is pre-order, so any enclosing list was seen first.
            if any(name.startswith(prefix) for prefix in collected_prefixes):
                continue
            collected_prefixes.append(f"{name}." if name else "")
            layers.extend(sub)
        return layers