| | import torch |
| | from torch import nn |
| | from torch import Tensor, LongTensor |
| |
|
| | from transformers import AutoTokenizer, AutoModel |
| |
|
| |
|
class WordTransformerEncoder(nn.Module):
    """
    Encodes sentences into word-level embeddings using a pretrained MLM transformer.

    Each sentence arrives pre-split into words; the tokenizer splits words into
    subtokens, the transformer embeds the subtokens, and the subtoken embeddings
    belonging to each word are mean-pooled back into one embedding per word.
    """

    def __init__(self, model_name: str):
        """
        Args:
            model_name: Hugging Face hub name (or local path) of the pretrained
                tokenizer/model pair to load.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, words: list[list[str]]) -> Tensor:
        """
        Build word embeddings for a batch of pre-split sentences.

        - Tokenizes input sentences into subtokens.
        - Passes the subtokens through the pre-trained transformer model.
        - Aggregates subtoken embeddings into word embeddings using mean pooling.

        Args:
            words: Batch of sentences, each given as a list of words.

        Returns:
            Tensor of shape (batch_size, max_n_words, embedding_size). Positions
            past a sentence's last word — and words lost to truncation — are
            zero vectors.
        """
        batch_size = len(words)

        subtokens = self.tokenizer(
            words,
            padding=True,
            truncation=True,
            is_split_into_words=True,
            return_tensors='pt',
        )
        subtokens = subtokens.to(self.model.device)

        # Map every subtoken position to its word index shifted by +1, so that
        # special/padding subtokens (word_id is None) all collapse into slot 0,
        # which is dropped after pooling. Stacking is safe because padding=True
        # makes every row the same length.
        words_ids = torch.stack([
            torch.tensor(
                [word_id + 1 if word_id is not None else 0 for word_id in subtokens.word_ids(batch_idx)],
                dtype=torch.long,
                device=self.model.device,
            )
            for batch_idx in range(batch_size)
        ])

        subtokens_embeddings = self.model(**subtokens).last_hidden_state

        # Fix: variable name typo (was `words_emeddings`).
        words_embeddings = self._aggregate_subtokens_embeddings(subtokens_embeddings, words_ids)
        return words_embeddings

    def _aggregate_subtokens_embeddings(
        self,
        subtokens_embeddings: Tensor,
        words_ids: LongTensor,
    ) -> Tensor:
        """
        Aggregate subtoken embeddings into word embeddings by averaging.

        This method ensures that multiple subtokens corresponding to a single
        word are combined into a single embedding.

        Args:
            subtokens_embeddings: (batch_size, n_subtokens, embedding_size).
            words_ids: (batch_size, n_subtokens) word index per subtoken, where
                index 0 is the special/padding bucket that is dropped from the
                output.

        Returns:
            Tensor of shape (batch_size, n_words, embedding_size); word slots
            that received no subtokens stay zero.
        """
        batch_size, n_subtokens, embedding_size = subtokens_embeddings.shape
        # Fix: make the size a plain int instead of a 0-dim tensor.
        n_words = int(words_ids.max().item()) + 1

        # Fix: derive the device from the input instead of self.model, so the
        # helper has no hidden dependency on the loaded model (same device in
        # practice, since the model produced these embeddings).
        words_embeddings = torch.zeros(
            size=(batch_size, n_words, embedding_size),
            dtype=subtokens_embeddings.dtype,
            device=subtokens_embeddings.device,
        )
        words_ids_expanded = words_ids.unsqueeze(-1).expand(batch_size, n_subtokens, embedding_size)

        # include_self=False: the zero initialization does not dilute the mean;
        # slots that receive no subtoken simply keep their zeros.
        words_embeddings.scatter_reduce_(
            dim=1,
            index=words_ids_expanded,
            src=subtokens_embeddings,
            reduce="mean",
            include_self=False,
        )
        # Drop slot 0 (special/padding subtokens).
        return words_embeddings[:, 1:, :]

    def get_embedding_size(self) -> int:
        """Returns the embedding size of the transformer model, e.g. 768 for BERT."""
        return self.model.config.hidden_size

    def get_embeddings_layer(self):
        """Returns the embeddings module of the underlying model."""
        return self.model.embeddings

    def get_transformer_layers(self) -> list[nn.Module]:
        """
        Return a flat list of all transformer-*block* layers, excluding embeddings/poolers, etc.

        NOTE(review): this collects the children of every nn.ModuleList in the
        model. For standard BERT-like encoders that is exactly the stack of
        transformer blocks, but verify for architectures that use ModuleList
        elsewhere.
        """
        layers = []
        for sub in self.model.modules():
            if isinstance(sub, nn.ModuleList):
                layers.extend(sub)
        return layers